# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch
def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    # Pure PyTorch fallback. Only the fixed-length case is handled: inplace, seqlen_offsets,
    # cu_seqlens and max_seqlen are accepted for API compatibility but are ignored here.
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1], "rotary_dim must be <= headdim"
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim:]
    if interleaved:
        # GPT-J style: rotate pairs of even and odd dimensions
        x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)
    else:
        # GPT-NeoX style: rotate pairs taken from the 1st and 2nd halves
        x1_1, x1_2 = x1.chunk(2, dim=-1)  # (..., rotary_dim/2)
    # Reshape cos/sin for broadcasting
    # x: [batch, seqlen, nheads, rotary_dim]
    # cos/sin: [seqlen, rotary_dim/2] -> [1, seqlen, 1, rotary_dim/2]
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    rot_x1 = x1_1 * cos - x1_2 * sin
    rot_x2 = x1_1 * sin + x1_2 * cos
    if interleaved:
        # Re-interleave the last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
        rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
    else:
        rot_x = torch.cat([rot_x1, rot_x2], dim=-1)
    out = torch.cat([rot_x, x2], dim=-1)
    return out
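

# Illustrative sketch only (not part of the original module): this hypothetical helper builds
# cos/sin tables of shape (seqlen, rotary_dim / 2), the layout apply_rotary_emb_torch expects,
# using the standard RoPE inverse-frequency formula. The name and the base=10000.0 default are
# assumptions made for demonstration.
def build_rotary_cos_sin(seqlen, rotary_dim, base=10000.0, device=None, dtype=torch.float32):
    # One frequency per rotated pair: inv_freq[i] = 1 / base**(2i / rotary_dim)
    inv_freq = 1.0 / (
        base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    # Angles for every (position, pair) combination: (seqlen, rotary_dim / 2)
    freqs = torch.outer(torch.arange(seqlen, device=device, dtype=torch.float32), inv_freq)
    return freqs.cos().to(dtype), freqs.sin().to(dtype)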
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    # We are forcing the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
    # for all devices. The custom Triton kernel (`ApplyRotaryEmb`) was causing a graph
    # break in `torch.compile`, pushing expensive operations to the CPU.
    # By using the pure PyTorch version, `torch.compile` can create a single, fully-optimized
    # graph, which should resolve the CPU bottleneck and improve GPU utilization.
    return apply_rotary_emb_torch(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
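

if __name__ == "__main__":
    # Minimal usage sketch: shapes follow the docstring above; the concrete sizes are arbitrary,
    # and build_rotary_cos_sin is the hypothetical helper defined above for illustration.
    batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    cos, sin = build_rotary_cos_sin(seqlen, rotary_dim)
    out = apply_rotary_emb(x, cos, sin)
    # Only the first rotary_dim channels are rotated; the remainder passes through unchanged.
    assert out.shape == x.shape
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])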