OwlMaster committed on
Commit
c19926b
1 Parent(s): df106cf

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +631 -565
util.py CHANGED
@@ -1,565 +1,631 @@
1
- from typing import Optional, Tuple, Union
2
-
3
- import torch
4
- from einops import rearrange, repeat
5
- import torch.nn.functional as F
6
-
7
- #import triton
8
- #import triton.language as tl
9
-
10
-
11
- # @triton.autotune(
12
- # configs=[
13
- # triton.Config({"BLOCK_M": 2}),
14
- # triton.Config({"BLOCK_M": 4}),
15
- # triton.Config({"BLOCK_M": 8}),
16
- # triton.Config({"BLOCK_M": 16}),
17
- # ],
18
- # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
19
- # )
20
- #@triton.jit
21
- # def rotary_kernel(
22
- # OUT, # Pointers to matrices
23
- # X,
24
- # COS,
25
- # SIN,
26
- # CU_SEQLENS,
27
- # SEQLEN_OFFSETS, # this could be int or a pointer
28
- # # Matrix dimensions
29
- # seqlen,
30
- # nheads,
31
- # rotary_dim,
32
- # seqlen_ro,
33
- # CACHE_KEY_SEQLEN,
34
- # # strides
35
- # stride_out_batch,
36
- # stride_out_nheads,
37
- # stride_out_seqlen,
38
- # stride_out_headdim,
39
- # stride_x_batch,
40
- # stride_x_nheads,
41
- # stride_x_seqlen,
42
- # stride_x_headdim,
43
- # # Meta-parameters
44
- # BLOCK_K: tl.constexpr,
45
- # IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
46
- # IS_VARLEN: tl.constexpr,
47
- # INTERLEAVED: tl.constexpr,
48
- # CONJUGATE: tl.constexpr,
49
- # BLOCK_M: tl.constexpr,
50
- # ):
51
- # pid_m = tl.program_id(axis=0)
52
- # pid_batch = tl.program_id(axis=1)
53
- # pid_head = tl.program_id(axis=2)
54
- # rotary_dim_half = rotary_dim // 2
55
-
56
- # if not IS_VARLEN:
57
- # X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
58
- # OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
59
- # COS = COS + pid_batch * seqlen_ro * rotary_dim_half
60
- # SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
61
- # else:
62
- # start_idx = tl.load(CU_SEQLENS + pid_batch)
63
- # seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
64
- # X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
65
- # OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
66
-
67
- # if pid_m * BLOCK_M >= seqlen:
68
- # return
69
- # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
70
- # if not IS_SEQLEN_OFFSETS_TENSOR:
71
- # rm_cs = rm + SEQLEN_OFFSETS
72
- # else:
73
- # rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
74
- # rk = tl.arange(0, BLOCK_K)
75
- # rk_half = tl.arange(0, BLOCK_K // 2)
76
-
77
- # if not INTERLEAVED:
78
- # # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
79
- # X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
80
- # COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
81
- # SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
82
- # cos = tl.load(
83
- # COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
84
- # )
85
- # sin = tl.load(
86
- # SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
87
- # )
88
- # x0 = tl.load(
89
- # X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
90
- # )
91
- # x1 = tl.load(
92
- # X + rotary_dim_half * stride_x_headdim,
93
- # mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
94
- # other=0.0,
95
- # )
96
- # if CONJUGATE:
97
- # sin = -sin
98
- # o0 = x0 * cos - x1 * sin
99
- # o1 = x0 * sin + x1 * cos
100
- # # write back result
101
- # OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
102
- # tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
103
- # tl.store(
104
- # OUT + rotary_dim_half * stride_out_headdim,
105
- # o1,
106
- # mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
107
- # )
108
- # else:
109
- # # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
110
- # # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
111
- # # Loading x0 will be fast but x1 will be slow.
112
- # # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
113
- # # Then we do the calculation and use tl.where to pick out the right outputs for the even
114
- # # and for the odd indices.
115
- # rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
116
- # rk_repeat = tl.arange(0, BLOCK_K) // 2
117
- # X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
118
- # X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
119
- # COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
120
- # SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
121
- # cos = tl.load(
122
- # COS,
123
- # mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
124
- # other=1.0,
125
- # ).to(tl.float32)
126
- # sin = tl.load(
127
- # SIN,
128
- # mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
129
- # other=0.0,
130
- # ).to(tl.float32)
131
- # x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
132
- # tl.float32
133
- # )
134
- # x1 = tl.load(
135
- # X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
136
- # ).to(tl.float32)
137
- # if CONJUGATE:
138
- # sin = -sin
139
- # x0_cos = x0 * cos
140
- # x1_sin = x1 * sin
141
- # out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
142
- # OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
143
- # tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
144
-
145
-
146
- # def apply_rotary(
147
- # x: torch.Tensor,
148
- # cos: torch.Tensor,
149
- # sin: torch.Tensor,
150
- # seqlen_offsets: Union[int, torch.Tensor] = 0,
151
- # cu_seqlens: Optional[torch.Tensor] = None,
152
- # max_seqlen: Optional[int] = None,
153
- # interleaved=False,
154
- # inplace=False,
155
- # conjugate=False,
156
- # ) -> torch.Tensor:
157
- # """
158
- # Arguments:
159
- # x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
160
- # else (total_seqlen, nheads, headdim).
161
- # cos: (seqlen_ro, rotary_dim / 2)
162
- # sin: (seqlen_ro, rotary_dim / 2)
163
- # seqlen_offsets: integer or integer tensor of size (batch,)
164
- # cu_seqlens: (batch + 1,) or None
165
- # max_seqlen: int
166
- # Returns:
167
- # y: (batch, seqlen, nheads, headdim)
168
- # """
169
-
170
- # batch, nheads, seqlen, headdim = x.shape
171
-
172
- # batch_ro, seqlen_ro, rotary_dim = cos.shape
173
-
174
- # assert batch == batch_ro
175
- # assert sin.shape == cos.shape
176
- # rotary_dim *= 2
177
- # assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
178
- # assert headdim <= 256, "Only support headdim <= 256"
179
-
180
- # assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
181
-
182
- # assert (
183
- # cos.dtype == sin.dtype
184
- # ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
185
- # assert (
186
- # x.dtype == cos.dtype
187
- # ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
188
-
189
- # cos, sin = cos.contiguous(), sin.contiguous()
190
- # if isinstance(seqlen_offsets, torch.Tensor):
191
- # assert seqlen_offsets.shape == (batch,)
192
- # assert seqlen_offsets.dtype in [torch.int32, torch.int64]
193
- # seqlen_offsets = seqlen_offsets.contiguous()
194
- # else:
195
- # assert seqlen_offsets + seqlen <= seqlen_ro
196
-
197
- # output = torch.empty_like(x) if not inplace else x
198
- # if rotary_dim < headdim and not inplace:
199
- # output[..., rotary_dim:].copy_(x[..., rotary_dim:])
200
-
201
- # BLOCK_K = (
202
- # 32
203
- # if rotary_dim <= 32
204
- # else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
205
- # )
206
- # grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
207
- # BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
208
-
209
- # # Need this, otherwise Triton tries to launch from cuda:0 and we get
210
- # # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
211
- # with torch.cuda.device(x.device.index):
212
- # rotary_kernel[grid](
213
- # output, # data ptrs
214
- # x,
215
- # cos,
216
- # sin,
217
- # cu_seqlens,
218
- # seqlen_offsets,
219
- # seqlen, # shapes
220
- # nheads,
221
- # rotary_dim,
222
- # seqlen_ro,
223
- # seqlen // 128, # key for triton cache (limit number of compilations)
224
- # output.stride(0), # batch_strides
225
- # output.stride(-3), # nheads_stride
226
- # output.stride(-2), # seqlen_stride
227
- # output.stride(-1), # headdim_stride
228
- # x.stride(0), # batch_strides
229
- # x.stride(-3), # nheads stride
230
- # x.stride(-2), # seqlen stride
231
- # x.stride(-1), # headdim stride
232
- # BLOCK_K,
233
- # isinstance(seqlen_offsets, torch.Tensor),
234
- # False,
235
- # interleaved,
236
- # conjugate,
237
- # BLOCK_M,
238
- # )
239
- # return output
240
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Apply a rotary position embedding to the first ``rotary_dim`` features of ``x``.

    Pure-PyTorch, vectorized implementation (no Triton required).

    Arguments:
        x: (batch, nheads, seqlen, headdim).
        cos: (batch, seqlen_ro, rotary_dim / 2), same dtype as x.
        sin: (batch, seqlen_ro, rotary_dim / 2), same dtype as x.
        seqlen_offsets: integer or integer tensor of size (batch,). Token ``s`` of
            sequence ``b`` uses row ``s + offset`` of ``cos[b]`` / ``sin[b]``.
            Offsets are expected to be non-negative.
        cu_seqlens: accepted for interface compatibility only; not used here
            (the variable-length layout is not implemented).
        max_seqlen: accepted for interface compatibility only; not used here.
        interleaved: if True, rotate pairs of even and odd features
            (0, 1), (2, 3), ...; otherwise rotate the first half of the rotary
            features against the second half.
        inplace: if True, write the result into ``x`` and return ``x``.
        conjugate: if True, negate ``sin`` (inverse rotation; used for the backward pass).
    Returns:
        y: (batch, nheads, seqlen, headdim). Features at index >= rotary_dim are
        copied from ``x`` unchanged.
    """
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim_half = cos.shape
    rotary_dim = 2 * rotary_dim_half

    assert batch == batch_ro
    assert sin.shape == cos.shape
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        offsets = seqlen_offsets.to(device=cos.device, dtype=torch.long)
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro
        offsets = torch.full(
            (batch,), int(seqlen_offsets), dtype=torch.long, device=cos.device
        )

    # positions[b, s] is the row of cos[b] / sin[b] used by token s of sequence b.
    positions = offsets[:, None] + torch.arange(seqlen, device=cos.device)[None, :]
    in_range = (positions < seqlen_ro)[:, None, :, None]  # (batch, 1, seqlen, 1)
    # Clamp so the lookup below never reads out of bounds; rows that were out of
    # range are replaced by the identity rotation right after.
    safe_positions = positions.clamp(min=0, max=max(seqlen_ro - 1, 0))
    batch_index = torch.arange(batch, device=cos.device)[:, None]
    cos_sel = cos[batch_index, safe_positions][:, None]  # (batch, 1, seqlen, rotary_dim / 2)
    sin_sel = sin[batch_index, safe_positions][:, None]
    # Positions past the end of the table get cos=1 / sin=0, i.e. they pass through
    # unrotated instead of being left as uninitialized memory.
    cos_sel = torch.where(in_range, cos_sel, torch.ones_like(cos_sel))
    sin_sel = torch.where(in_range, sin_sel, torch.zeros_like(sin_sel))
    if conjugate:
        sin_sel = -sin_sel

    if interleaved:
        first = slice(0, rotary_dim, 2)
        second = slice(1, rotary_dim, 2)
    else:
        first = slice(0, rotary_dim_half)
        second = slice(rotary_dim_half, rotary_dim)

    # Both halves are fully computed before anything is written, so the in-place
    # path never reads a value it has already overwritten.
    x0 = x[..., first]
    x1 = x[..., second]
    o0 = x0 * cos_sel - x1 * sin_sel
    o1 = x0 * sin_sel + x1 * cos_sel

    if inplace:
        output = x
    else:
        output = torch.empty_like(x)
        if rotary_dim < headdim:
            output[..., rotary_dim:].copy_(x[..., rotary_dim:])
    output[..., first] = o0
    output[..., second] = o1
    return output
323
-
324
-
325
class ApplyRotaryEmb(torch.autograd.Function):
    """
    Autograd wrapper around ``apply_rotary``.

    The forward pass rotates ``x``; the backward pass applies the same rotation
    with ``conjugate=True`` (negated ``sin``) to the incoming gradient. Only ``x``
    receives a gradient; every other input gets ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            # Forwarded so forward and backward call apply_rotary the same way.
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            # save_for_backward only takes tensors, so an int offset is kept on ctx.
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            # None marks "the offsets are among the saved tensors".
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # No defensive clone of `do`: it was only a workaround for a Triton launch
        # error, and the non-inplace path of apply_rotary does not modify its input.
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        # One gradient slot per forward input after ctx (8 inputs in total).
        return dx, None, None, None, None, None, None, None
381
-
382
-
383
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Differentiable entry point for the rotary embedding (calls ``ApplyRotaryEmb.apply``).

    Arguments:
        x: (batch_size, nheads, seqlen, headdim) -- the layout ``apply_rotary`` unpacks.
        cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None. Passed through to ``apply_rotary``.
        max_seqlen: int or None. Passed through to ``apply_rotary``.
    Return:
        out: same shape as x.
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    # autograd.Function.apply is called positionally, in the order of forward()'s parameters.
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
418
-
419
-
420
class FastRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: rotary dimension; the frequency table has one entry per even index in [0, dim).
        base: base of the geometric progression of inverse frequencies.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, enables the XPos scaling of the cos/sin tables.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device on which the buffers are created.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Not read anywhere in this class; kept so existing attribute access still works.
        self.cos = None
        self.sin = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies 1 / base ** (i / dim) for i = 0, 2, ..., dim - 2."""
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)

    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
        """
        (Re)build the cached cos/sin tables when they cannot serve the request.

        The tables are rebuilt when no cache exists yet, when ``seqlen`` exceeds the
        cached length, or when the requested ``device`` / ``dtype`` differs from the
        cached tables. ``position_id`` is accepted for interface compatibility and
        is not used.
        """
        cached = self._cos_cached
        stale = (
            cached is None
            or seqlen > self._seq_len_cached
            or (device is not None and cached.device != torch.device(device))
            or (dtype is not None and cached.dtype != dtype)
        )
        if not stale:
            return

        # Never shrink the table when the rebuild was triggered by a dtype/device change.
        seqlen = max(seqlen, self._seq_len_cached)
        self._seq_len_cached = seqlen
        # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
        # And the output of arange can be quite large, so bf16 would lose a lot of precision.
        # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
        if self.pos_idx_in_fp32:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # We want fp32 here as well since inv_freq will be multiplied with t, and the output
            # will be large. Having it in bf16 will lose a lot of precision and cause the
            # cos & sin output to change significantly.
            # We want to recompute self.inv_freq if it was not loaded in fp32
            if self.inv_freq.dtype != torch.float32:
                inv_freq = self._compute_inv_freq(device=device)
            else:
                inv_freq = self.inv_freq
        else:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seqlen, dim / 2)
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                - seqlen // 2
            ) / self.scale_base
            # (dim / 2,) ** (seqlen, 1) -> (seqlen, dim / 2)
            scale = self.scale.to(device=power.device) ** power[:, None]
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
        max_seqlen,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, nheads, seqlen, headdim)
        k: (batch, nheads, seqlen, headdim)
        position_ids: (batch, seqlen) integer row indices into the cached tables
        max_seqlen: int, number of rows the cached cos/sin tables must have
        Apply rotary embedding *inplace* to q k.
        """
        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
        # Look up one table row per token: (batch, seqlen, dim / 2).
        cos = F.embedding(position_ids, self._cos_cached)
        sin = F.embedding(position_ids, self._sin_cached)
        if self.scale is None:
            cos_k, sin_k = cos, sin
        else:
            # XPos: keys use the tables divided by the scale, queries the multiplied ones.
            cos_k = F.embedding(position_ids, self._cos_k_cached)
            sin_k = F.embedding(position_ids, self._sin_k_cached)

        q = apply_rotary_emb_func(
            q,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True,
        )
        k = apply_rotary_emb_func(
            k,
            cos_k,
            sin_k,
            interleaved=self.interleaved,
            inplace=True,
        )
        return q, k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ import torch.nn.functional as F
6
+
7
+ #import triton
8
+ #import triton.language as tl
9
+
10
+
11
+ # @triton.autotune(
12
+ # configs=[
13
+ # triton.Config({"BLOCK_M": 2}),
14
+ # triton.Config({"BLOCK_M": 4}),
15
+ # triton.Config({"BLOCK_M": 8}),
16
+ # triton.Config({"BLOCK_M": 16}),
17
+ # ],
18
+ # key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
19
+ # )
20
+ #@triton.jit
21
+ # def rotary_kernel(
22
+ # OUT, # Pointers to matrices
23
+ # X,
24
+ # COS,
25
+ # SIN,
26
+ # CU_SEQLENS,
27
+ # SEQLEN_OFFSETS, # this could be int or a pointer
28
+ # # Matrix dimensions
29
+ # seqlen,
30
+ # nheads,
31
+ # rotary_dim,
32
+ # seqlen_ro,
33
+ # CACHE_KEY_SEQLEN,
34
+ # # strides
35
+ # stride_out_batch,
36
+ # stride_out_nheads,
37
+ # stride_out_seqlen,
38
+ # stride_out_headdim,
39
+ # stride_x_batch,
40
+ # stride_x_nheads,
41
+ # stride_x_seqlen,
42
+ # stride_x_headdim,
43
+ # # Meta-parameters
44
+ # BLOCK_K: tl.constexpr,
45
+ # IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
46
+ # IS_VARLEN: tl.constexpr,
47
+ # INTERLEAVED: tl.constexpr,
48
+ # CONJUGATE: tl.constexpr,
49
+ # BLOCK_M: tl.constexpr,
50
+ # ):
51
+ # pid_m = tl.program_id(axis=0)
52
+ # pid_batch = tl.program_id(axis=1)
53
+ # pid_head = tl.program_id(axis=2)
54
+ # rotary_dim_half = rotary_dim // 2
55
+
56
+ # if not IS_VARLEN:
57
+ # X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
58
+ # OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
59
+ # COS = COS + pid_batch * seqlen_ro * rotary_dim_half
60
+ # SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
61
+ # else:
62
+ # start_idx = tl.load(CU_SEQLENS + pid_batch)
63
+ # seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
64
+ # X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
65
+ # OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
66
+
67
+ # if pid_m * BLOCK_M >= seqlen:
68
+ # return
69
+ # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
70
+ # if not IS_SEQLEN_OFFSETS_TENSOR:
71
+ # rm_cs = rm + SEQLEN_OFFSETS
72
+ # else:
73
+ # rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
74
+ # rk = tl.arange(0, BLOCK_K)
75
+ # rk_half = tl.arange(0, BLOCK_K // 2)
76
+
77
+ # if not INTERLEAVED:
78
+ # # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
79
+ # X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
80
+ # COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
81
+ # SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
82
+ # cos = tl.load(
83
+ # COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
84
+ # )
85
+ # sin = tl.load(
86
+ # SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
87
+ # )
88
+ # x0 = tl.load(
89
+ # X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
90
+ # )
91
+ # x1 = tl.load(
92
+ # X + rotary_dim_half * stride_x_headdim,
93
+ # mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
94
+ # other=0.0,
95
+ # )
96
+ # if CONJUGATE:
97
+ # sin = -sin
98
+ # o0 = x0 * cos - x1 * sin
99
+ # o1 = x0 * sin + x1 * cos
100
+ # # write back result
101
+ # OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
102
+ # tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
103
+ # tl.store(
104
+ # OUT + rotary_dim_half * stride_out_headdim,
105
+ # o1,
106
+ # mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
107
+ # )
108
+ # else:
109
+ # # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
110
+ # # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
111
+ # # Loading x0 will be fast but x1 will be slow.
112
+ # # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
113
+ # # Then we do the calculation and use tl.where to pick out the right outputs for the even
114
+ # # and for the odd indices.
115
+ # rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
116
+ # rk_repeat = tl.arange(0, BLOCK_K) // 2
117
+ # X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
118
+ # X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
119
+ # COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
120
+ # SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
121
+ # cos = tl.load(
122
+ # COS,
123
+ # mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
124
+ # other=1.0,
125
+ # ).to(tl.float32)
126
+ # sin = tl.load(
127
+ # SIN,
128
+ # mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
129
+ # other=0.0,
130
+ # ).to(tl.float32)
131
+ # x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
132
+ # tl.float32
133
+ # )
134
+ # x1 = tl.load(
135
+ # X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
136
+ # ).to(tl.float32)
137
+ # if CONJUGATE:
138
+ # sin = -sin
139
+ # x0_cos = x0 * cos
140
+ # x1_sin = x1 * sin
141
+ # out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
142
+ # OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
143
+ # tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
144
+
145
+
146
+ # def apply_rotary(
147
+ # x: torch.Tensor,
148
+ # cos: torch.Tensor,
149
+ # sin: torch.Tensor,
150
+ # seqlen_offsets: Union[int, torch.Tensor] = 0,
151
+ # cu_seqlens: Optional[torch.Tensor] = None,
152
+ # max_seqlen: Optional[int] = None,
153
+ # interleaved=False,
154
+ # inplace=False,
155
+ # conjugate=False,
156
+ # ) -> torch.Tensor:
157
+ # """
158
+ # Arguments:
159
+ # x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
160
+ # else (total_seqlen, nheads, headdim).
161
+ # cos: (seqlen_ro, rotary_dim / 2)
162
+ # sin: (seqlen_ro, rotary_dim / 2)
163
+ # seqlen_offsets: integer or integer tensor of size (batch,)
164
+ # cu_seqlens: (batch + 1,) or None
165
+ # max_seqlen: int
166
+ # Returns:
167
+ # y: (batch, seqlen, nheads, headdim)
168
+ # """
169
+
170
+ # batch, nheads, seqlen, headdim = x.shape
171
+
172
+ # batch_ro, seqlen_ro, rotary_dim = cos.shape
173
+
174
+ # assert batch == batch_ro
175
+ # assert sin.shape == cos.shape
176
+ # rotary_dim *= 2
177
+ # assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
178
+ # assert headdim <= 256, "Only support headdim <= 256"
179
+
180
+ # assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
181
+
182
+ # assert (
183
+ # cos.dtype == sin.dtype
184
+ # ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
185
+ # assert (
186
+ # x.dtype == cos.dtype
187
+ # ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
188
+
189
+ # cos, sin = cos.contiguous(), sin.contiguous()
190
+ # if isinstance(seqlen_offsets, torch.Tensor):
191
+ # assert seqlen_offsets.shape == (batch,)
192
+ # assert seqlen_offsets.dtype in [torch.int32, torch.int64]
193
+ # seqlen_offsets = seqlen_offsets.contiguous()
194
+ # else:
195
+ # assert seqlen_offsets + seqlen <= seqlen_ro
196
+
197
+ # output = torch.empty_like(x) if not inplace else x
198
+ # if rotary_dim < headdim and not inplace:
199
+ # output[..., rotary_dim:].copy_(x[..., rotary_dim:])
200
+
201
+ # BLOCK_K = (
202
+ # 32
203
+ # if rotary_dim <= 32
204
+ # else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
205
+ # )
206
+ # grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
207
+ # BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
208
+
209
+ # # Need this, otherwise Triton tries to launch from cuda:0 and we get
210
+ # # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
211
+ # with torch.cuda.device(x.device.index):
212
+ # rotary_kernel[grid](
213
+ # output, # data ptrs
214
+ # x,
215
+ # cos,
216
+ # sin,
217
+ # cu_seqlens,
218
+ # seqlen_offsets,
219
+ # seqlen, # shapes
220
+ # nheads,
221
+ # rotary_dim,
222
+ # seqlen_ro,
223
+ # seqlen // 128, # key for triton cache (limit number of compilations)
224
+ # output.stride(0), # batch_strides
225
+ # output.stride(-3), # nheads_stride
226
+ # output.stride(-2), # seqlen_stride
227
+ # output.stride(-1), # headdim_stride
228
+ # x.stride(0), # batch_strides
229
+ # x.stride(-3), # nheads stride
230
+ # x.stride(-2), # seqlen stride
231
+ # x.stride(-1), # headdim stride
232
+ # BLOCK_K,
233
+ # isinstance(seqlen_offsets, torch.Tensor),
234
+ # False,
235
+ # interleaved,
236
+ # conjugate,
237
+ # BLOCK_M,
238
+ # )
239
+ # return output
240
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Apply a rotary position embedding to the first ``rotary_dim`` features of ``x``.

    Pure-PyTorch, vectorized implementation (no Triton required).

    Arguments:
        x: (batch, nheads, seqlen, headdim).
        cos: (batch, seqlen_ro, rotary_dim / 2), same dtype as x.
        sin: (batch, seqlen_ro, rotary_dim / 2), same dtype as x.
        seqlen_offsets: integer or integer tensor of size (batch,). Token ``s`` of
            sequence ``b`` uses row ``s + offset`` of ``cos[b]`` / ``sin[b]``.
            Offsets are expected to be non-negative.
        cu_seqlens: accepted for interface compatibility only; not used here
            (the variable-length layout is not implemented).
        max_seqlen: accepted for interface compatibility only; not used here.
        interleaved: if True, rotate pairs of even and odd features
            (0, 1), (2, 3), ...; otherwise rotate the first half of the rotary
            features against the second half.
        inplace: if True, write the result into ``x`` and return ``x``.
        conjugate: if True, negate ``sin`` (inverse rotation; used for the backward pass).
    Returns:
        y: (batch, nheads, seqlen, headdim). Features at index >= rotary_dim are
        copied from ``x`` unchanged.
    """
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim_half = cos.shape
    rotary_dim = 2 * rotary_dim_half

    assert batch == batch_ro
    assert sin.shape == cos.shape
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        offsets = seqlen_offsets.to(device=cos.device, dtype=torch.long)
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro
        offsets = torch.full(
            (batch,), int(seqlen_offsets), dtype=torch.long, device=cos.device
        )

    # positions[b, s] is the row of cos[b] / sin[b] used by token s of sequence b.
    positions = offsets[:, None] + torch.arange(seqlen, device=cos.device)[None, :]
    in_range = (positions < seqlen_ro)[:, None, :, None]  # (batch, 1, seqlen, 1)
    # Clamp so the lookup below never reads out of bounds; rows that were out of
    # range are replaced by the identity rotation right after.
    safe_positions = positions.clamp(min=0, max=max(seqlen_ro - 1, 0))
    batch_index = torch.arange(batch, device=cos.device)[:, None]
    cos_sel = cos[batch_index, safe_positions][:, None]  # (batch, 1, seqlen, rotary_dim / 2)
    sin_sel = sin[batch_index, safe_positions][:, None]
    # Positions past the end of the table get cos=1 / sin=0, i.e. they pass through
    # unrotated instead of being left as uninitialized memory.
    cos_sel = torch.where(in_range, cos_sel, torch.ones_like(cos_sel))
    sin_sel = torch.where(in_range, sin_sel, torch.zeros_like(sin_sel))
    if conjugate:
        sin_sel = -sin_sel

    if interleaved:
        first = slice(0, rotary_dim, 2)
        second = slice(1, rotary_dim, 2)
    else:
        first = slice(0, rotary_dim_half)
        second = slice(rotary_dim_half, rotary_dim)

    # Both halves are fully computed before anything is written, so the in-place
    # path never reads a value it has already overwritten.
    x0 = x[..., first]
    x1 = x[..., second]
    o0 = x0 * cos_sel - x1 * sin_sel
    o1 = x0 * sin_sel + x1 * cos_sel

    if inplace:
        output = x
    else:
        output = torch.empty_like(x)
        if rotary_dim < headdim:
            output[..., rotary_dim:].copy_(x[..., rotary_dim:])
    output[..., first] = o0
    output[..., second] = o1
    return output
323
+
324
def apply_rotary_optimized(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Vectorised rotary position embedding over the first ``rotary_dim`` features of x.

    Arguments:
        x: (batch, nheads, seqlen, headdim)
        cos: (batch, seqlen_ro, rotary_dim / 2)
        sin: (batch, seqlen_ro, rotary_dim / 2), same shape and dtype as cos
        seqlen_offsets: int, or integer tensor of size (batch,). Sequence position ``s`` of
            batch element ``b`` is rotated with row ``s + seqlen_offsets[b]`` of cos/sin.
        cu_seqlens: accepted for interface compatibility; not used by this implementation.
        max_seqlen: accepted for interface compatibility; not used by this implementation.
        interleaved: if True, rotate (even, odd) feature pairs; otherwise rotate the first
            half of the rotary features against the second half.
        inplace: if True, write the result into x and return x.
        conjugate: if True, negate sin, i.e. apply the inverse rotation.
    Returns:
        y: (batch, nheads, seqlen, headdim). Features at index >= rotary_dim are copied
            from x unchanged.
    """
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim_half = cos.shape
    rotary_dim = 2 * rotary_dim_half

    assert batch == batch_ro
    assert sin.shape == cos.shape
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        # Index arithmetic below needs int64 indices living on the same device as x.
        offsets = seqlen_offsets.to(device=x.device, dtype=torch.long)
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro
        offsets = torch.full((batch,), seqlen_offsets, device=x.device, dtype=torch.long)

    output = x if inplace else torch.empty_like(x)
    if rotary_dim < headdim and not inplace:
        # Features beyond rotary_dim are passed through untouched.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    # positions[b, s] is the row of the cos/sin table used by token s of batch element b.
    positions = torch.arange(seqlen, device=x.device).unsqueeze(0) + offsets.unsqueeze(1)
    in_range = (positions < seqlen_ro).unsqueeze(-1)
    # Clamp only to keep the lookup legal; rows that were out of range are overwritten below.
    safe_positions = positions.clamp(max=seqlen_ro - 1)
    batch_index = torch.arange(batch, device=x.device).unsqueeze(1)

    cos_sel = cos[batch_index, safe_positions]  # (batch, seqlen, rotary_dim / 2)
    sin_sel = sin[batch_index, safe_positions]
    # Tensor offsets can point past the end of the table. Such positions get the identity
    # rotation (cos = 1, sin = 0) instead of silently reusing the last row's angle.
    cos_sel = cos_sel.masked_fill(~in_range, 1.0)
    sin_sel = sin_sel.masked_fill(~in_range, 0.0)
    if conjugate:
        sin_sel = -sin_sel

    # Insert a singleton head axis so the tables broadcast over nheads.
    cos_b = cos_sel.unsqueeze(1)
    sin_b = sin_sel.unsqueeze(1)

    if interleaved:
        first = slice(0, rotary_dim, 2)
        second = slice(1, rotary_dim, 2)
    else:
        first = slice(0, rotary_dim_half)
        second = slice(rotary_dim_half, rotary_dim)

    x0 = x[..., first]
    x1 = x[..., second]
    # Both results are materialised before any write, so inplace=True (output is x)
    # never reads values that were already overwritten.
    o0 = x0 * cos_b - x1 * sin_b
    o1 = x0 * sin_b + x1 * cos_b
    output[..., first] = o0
    output[..., second] = o1

    return output
391
class ApplyRotaryEmb(torch.autograd.Function):
    """
    Autograd function for rotary position embedding.

    The forward pass rotates x with apply_rotary_optimized. The backward pass applies the
    conjugate (inverse) rotation to the incoming gradient using the same routine, so both
    directions share one implementation.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """Rotate x and stash what backward needs. Returns x itself when inplace is True."""
        out = apply_rotary_optimized(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            # Can't save int with save_for_backward, so keep it as a plain attribute.
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            # None marks "offsets are among the saved tensors" for backward.
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        """Return the gradient w.r.t. x; every other forward input gets None."""
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # Use the same vectorised routine as forward (with conjugate=True) rather than the
        # per-element reference loop: that loop skips out-of-range positions and would leave
        # those entries of a freshly allocated dx uninitialised. No defensive clone of `do`
        # is needed: the non-inplace path never writes to its input.
        dx = apply_rotary_optimized(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        # One gradient slot per forward argument after ctx (8 in total).
        return dx, None, None, None, None, None, None, None
447
+
448
+
449
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Differentiable entry point: apply rotary embedding to the first rotary_dim of x.

    Arguments:
        x: (batch_size, nheads, seqlen, headdim), as unpacked by apply_rotary_optimized
        cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None; forwarded to ApplyRotaryEmb unchanged.
        max_seqlen: int or None; forwarded to ApplyRotaryEmb unchanged.
    Return:
        out: same shape as x. rotary_dim must be <= headdim.
    """
    # Arguments are handed over positionally, in the order ApplyRotaryEmb.forward declares them.
    forward_args = (x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
    return ApplyRotaryEmb.apply(*forward_args)


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
484
+
485
+
486
class FastRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: number of rotary features (rotary_dim); the cos/sin tables have dim / 2 columns.
        base: base of the geometric progression of inverse frequencies.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, enables the XPos scaling of the cos/sin tables.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device on which the inv_freq and scale buffers are created.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Not written anywhere else in this class; kept so existing attribute reads keep working.
        self.cos = None
        self.sin = None

    def _compute_inv_freq(self, device=None):
        """Return the (dim / 2,) inverse frequencies 1 / base ** (2i / dim) in fp32."""
        # Explicit fp32 so the result does not depend on the global default dtype.
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)

    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
        """
        (Re)build the cached cos/sin tables when the current ones cannot serve the request.

        The tables are rebuilt when none exist yet, when seqlen exceeds the cached length,
        or when the requested device / dtype differs from the cached tables. position_id is
        accepted for interface compatibility and is not used.
        """
        cached = self._cos_cached
        stale = (
            cached is None
            or seqlen > self._seq_len_cached
            or (device is not None and cached.device != torch.device(device))
            or (dtype is not None and cached.dtype != dtype)
        )
        if not stale:
            return

        # Never shrink: positions that were valid before a dtype/device switch stay valid.
        seqlen = max(seqlen, self._seq_len_cached)
        self._seq_len_cached = seqlen
        # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
        # And the output of arange can be quite large, so bf16 would lose a lot of precision.
        # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
        if self.pos_idx_in_fp32:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # We want fp32 here as well since inv_freq will be multiplied with t, and the output
            # will be large. Having it in bf16 will lose a lot of precision and cause the
            # cos & sin output to change significantly.
            # We want to recompute self.inv_freq if it was not loaded in fp32
            if self.inv_freq.dtype != torch.float32:
                inv_freq = self._compute_inv_freq(device=device)
            else:
                inv_freq = self.inv_freq
        else:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            # Build the exponents on the same device as freqs so the products below never
            # mix devices.
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=freqs.device)
                - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=freqs.device) ** power.unsqueeze(-1)
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
        max_seqlen,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, nheads, seqlen, headdim)
        k: (batch, nheads, seqlen, headdim)
        position_ids: (batch, seqlen) integer positions used to look up rows of the tables
        max_seqlen: int, minimum length the cos/sin tables must cover
        Apply rotary embedding *inplace* to q k and return them.
        """
        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
        cos = F.embedding(position_ids, self._cos_cached)
        sin = F.embedding(position_ids, self._sin_cached)
        if self.scale is None:
            cos_k, sin_k = cos, sin
        else:
            # XPos: keys use the inverse-scaled tables, queries the scaled ones.
            cos_k = F.embedding(position_ids, self._cos_k_cached)
            sin_k = F.embedding(position_ids, self._sin_k_cached)

        q = apply_rotary_emb_func(
            q,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True,
        )
        k = apply_rotary_emb_func(
            k,
            cos_k,
            sin_k,
            interleaved=self.interleaved,
            inplace=True,
        )
        return q, k