from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Reference (loop-based) implementation of rotary position embeddings.

    Arguments:
        x: (batch, nheads, seqlen, headdim)
        cos: (batch, seqlen_ro, rotary_dim / 2)
        sin: (batch, seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,), the position
            offset of each sequence (e.g. when reading from a KV cache).
        cu_seqlens: (batch + 1,) or None. Kept for API compatibility; unused here.
        max_seqlen: int. Kept for API compatibility; unused here.
        interleaved: if True, rotate adjacent even/odd pairs (GPT-J style) instead
            of the 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, write the result into x.
        conjugate: if True, negate sin (used for the backward pass).
    Returns:
        y: (batch, nheads, seqlen, headdim)
    """
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim = cos.shape

    assert batch == batch_ro
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        # Head dimensions beyond rotary_dim are passed through unchanged.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    rotary_dim_half = rotary_dim // 2
    for b in range(batch):
        for h in range(nheads):
            for s in range(seqlen):
                idx = s + seqlen_offsets if isinstance(seqlen_offsets, int) else s + seqlen_offsets[b]
                if idx >= seqlen_ro:
                    continue

                cos_idx = cos[b, idx, :rotary_dim_half]
                sin_idx = sin[b, idx, :rotary_dim_half]
                if conjugate:
                    sin_idx = -sin_idx

                if not interleaved:
                    # GPT-NeoX style: rotate (first half, second half) pairs.
                    x0 = x[b, h, s, :rotary_dim_half]
                    x1 = x[b, h, s, rotary_dim_half:rotary_dim]
                    o0 = x0 * cos_idx - x1 * sin_idx
                    o1 = x0 * sin_idx + x1 * cos_idx
                    output[b, h, s, :rotary_dim_half] = o0
                    output[b, h, s, rotary_dim_half:rotary_dim] = o1
                else:
                    # GPT-J style: rotate (even, odd) adjacent pairs. Read both
                    # elements of a pair before writing so inplace=True stays correct.
                    for i in range(0, rotary_dim, 2):
                        x0 = x[b, h, s, i].clone()
                        x1 = x[b, h, s, i + 1].clone()
                        output[b, h, s, i] = x0 * cos_idx[i // 2] - x1 * sin_idx[i // 2]
                        output[b, h, s, i + 1] = x0 * sin_idx[i // 2] + x1 * cos_idx[i // 2]

    return output
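

# Added illustrative sketch (not part of the original API): a minimal call to the
# reference implementation above, using small, assumed toy shapes. For each pair
# (x0, x1) and angle theta it applies y0 = x0*cos(theta) - x1*sin(theta) and
# y1 = x0*sin(theta) + x1*cos(theta); head dimensions beyond rotary_dim pass through.
def _example_apply_rotary() -> torch.Tensor:
    batch, nheads, seqlen, headdim, rotary_dim = 1, 2, 4, 16, 8
    x = torch.randn(batch, nheads, seqlen, headdim)
    # Standard RoPE angles: position index times per-pair inverse frequency.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, rotary_dim / 2)
    cos = torch.cos(freqs).expand(batch, -1, -1)  # (batch, seqlen, rotary_dim / 2)
    sin = torch.sin(freqs).expand(batch, -1, -1)
    return apply_rotary(x, cos, sin)  # same shape as x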


def apply_rotary_optimized(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """Vectorized equivalent of apply_rotary (same arguments and shapes)."""
    batch, nheads, seqlen, headdim = x.shape
    batch_ro, seqlen_ro, rotary_dim = cos.shape

    assert batch == batch_ro
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro
        seqlen_offsets = torch.full((batch,), seqlen_offsets, device=x.device, dtype=torch.long)

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    rotary_dim_half = rotary_dim // 2

    # Absolute position of each (batch, seq) element, clamped so out-of-range
    # positions reuse the last available cos/sin row.
    seq_indices = torch.arange(seqlen, device=x.device).unsqueeze(0) + seqlen_offsets.unsqueeze(1)
    seq_indices = seq_indices.clamp(max=seqlen_ro - 1)

    # Gather the cos/sin rows for every position: (batch, seqlen, rotary_dim_half).
    cos_gathered = cos.gather(1, seq_indices.unsqueeze(-1).expand(-1, -1, rotary_dim_half))
    sin_gathered = sin.gather(1, seq_indices.unsqueeze(-1).expand(-1, -1, rotary_dim_half))

    if conjugate:
        sin_gathered = -sin_gathered

    if not interleaved:
        # GPT-NeoX style: split the rotary slice into (first half, second half).
        x_rotary = x[..., :rotary_dim].view(batch, nheads, seqlen, 2, -1)
        x0, x1 = x_rotary.unbind(dim=-2)

        o0 = x0 * cos_gathered.unsqueeze(1) - x1 * sin_gathered.unsqueeze(1)
        o1 = x0 * sin_gathered.unsqueeze(1) + x1 * cos_gathered.unsqueeze(1)

        output[..., :rotary_dim] = torch.stack([o0, o1], dim=-2).view(batch, nheads, seqlen, -1)
    else:
        # GPT-J style: split the rotary slice into (even, odd) adjacent pairs.
        x_rotary = x[..., :rotary_dim].view(batch, nheads, seqlen, rotary_dim // 2, 2)
        x0, x1 = x_rotary.unbind(dim=-1)

        o0 = x0 * cos_gathered.unsqueeze(1) - x1 * sin_gathered.unsqueeze(1)
        o1 = x0 * sin_gathered.unsqueeze(1) + x1 * cos_gathered.unsqueeze(1)

        output[..., :rotary_dim] = torch.stack([o0, o1], dim=-1).view(batch, nheads, seqlen, -1)

    return output
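

# Added consistency sketch (not part of the original API): compares the reference
# loop and the vectorized version on random inputs with assumed toy shapes, for
# both the GPT-NeoX and GPT-J (interleaved) layouts.
def _check_optimized_matches_reference() -> bool:
    torch.manual_seed(0)
    batch, nheads, seqlen, headdim, rotary_dim = 2, 4, 8, 16, 8
    x = torch.randn(batch, nheads, seqlen, headdim)
    angles = torch.randn(batch, seqlen, rotary_dim // 2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    ok = True
    for interleaved in (False, True):
        ref = apply_rotary(x.clone(), cos, sin, interleaved=interleaved)
        fast = apply_rotary_optimized(x.clone(), cos, sin, interleaved=interleaved)
        ok = ok and torch.allclose(ref, fast, atol=1e-6)
    return ok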


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary_optimized(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            # save_for_backward only accepts tensors (or None), so an int offset
            # is stashed on ctx directly.
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors

        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        # The backward of a rotation is the rotation by the negated angle,
        # hence conjugate=True.
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, nheads, seqlen, headdim)
        cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have a KV cache.
        cu_seqlens: (batch + 1,) or None. Kept for API compatibility; unused by the
            implementations above.
        max_seqlen: int
    Return:
        out: (batch_size, nheads, seqlen, headdim)
    rotary_dim must be <= headdim.
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )


apply_rotary_emb_func = apply_rotary_emb
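

# Added usage sketch (the toy shapes below are assumptions): the autograd wrapper
# is differentiable with respect to x; gradients flow back through
# ApplyRotaryEmb.backward, which rotates by the negated angle.
def _example_apply_rotary_emb_grad() -> torch.Tensor:
    batch, nheads, seqlen, headdim = 2, 4, 8, 32
    q = torch.randn(batch, nheads, seqlen, headdim, requires_grad=True)
    freqs = torch.outer(
        torch.arange(seqlen, dtype=torch.float32),
        torch.rand(headdim // 2),  # toy per-pair angles; real RoPE uses 1 / base^(2j/dim)
    )
    cos = torch.cos(freqs).expand(batch, -1, -1)
    sin = torch.sin(freqs).expand(batch, -1, -1)
    out = apply_rotary_emb_func(q, cos, sin)  # keep inplace=False so q is not overwritten
    out.sum().backward()
    return q.grad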


class FastRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we constructed
            the position indices, we used the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices would also be in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.pos_idx_in_fp32 = pos_idx_in_fp32

        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # XPos per-frequency scaling factors (None when scale_base is unset).
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.cos = None
        self.sin = None

    def _compute_inv_freq(self, device=None):
        # inv_freq[j] = 1 / base^(2j / dim), the standard RoPE frequencies.
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
        # Recompute the cache if the requested length grew, or if the cached
        # tensors live on the wrong device or have the wrong dtype (e.g. after
        # moving the module or switching precision).
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # Keep the frequencies in fp32 as well, otherwise t * inv_freq
                # would be computed (and rounded) in the lower precision.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: scale queries and keys in opposite directions around the
                # middle of the sequence.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
        max_seqlen,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, nheads, seqlen, headdim)
        k: (batch, nheads, seqlen, headdim)
        position_ids: (batch, seqlen)
        max_seqlen: int, used to size the cos/sin cache.
        Apply rotary embedding *inplace* to q and k.
        """
        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
        # Look up the cached cos/sin rows for each position: (batch, seqlen, rotary_dim / 2).
        cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)

        q = apply_rotary_emb_func(
            q,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True,
        )
        k = apply_rotary_emb_func(
            k,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True,
        )
        return q, k
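

# Added end-to-end sketch (head counts and lengths are assumed toy values): rotate
# query/key tensors with the module above before computing attention scores.
def _example_fast_rotary_embedding() -> Tuple[torch.Tensor, torch.Tensor]:
    batch, nheads, seqlen, headdim = 2, 8, 16, 64
    rope = FastRotaryEmbedding(dim=headdim)
    q = torch.randn(batch, nheads, seqlen, headdim)
    k = torch.randn(batch, nheads, seqlen, headdim)
    position_ids = torch.arange(seqlen).unsqueeze(0).expand(batch, -1)
    # Note: the module applies the rotation with inplace=True, so q and k are modified.
    q_rot, k_rot = rope(q, k, position_ids, max_seqlen=seqlen)
    return q_rot, k_rot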