# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch

def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    # Pure PyTorch rotary embedding. Supports the dense (non-varlen) layout with an
    # integer seqlen_offsets, in both the GPT-NeoX (half-split) and GPT-J
    # (interleaved) styles. The result is always returned as a new tensor, even if
    # inplace=True is passed.
    assert cu_seqlens is None and max_seqlen is None, "variable-length sequences are not supported"
    assert isinstance(seqlen_offsets, int), "per-sequence (tensor) offsets are not supported"
    seqlen = x.shape[1]
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim:]
    # Select the positions of this sequence and reshape cos/sin for broadcasting:
    # x: [batch, seqlen, nheads, rotary_dim], cos/sin: [seqlen_rotary, rotary_dim/2]
    # -> [1, seqlen, 1, rotary_dim/2]
    cos = cos[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
    sin = sin[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
    if interleaved:
        # GPT-J style: rotate (even, odd) pairs...
        x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)
        rot_x1 = x1_1 * cos - x1_2 * sin
        rot_x2 = x1_1 * sin + x1_2 * cos
        # ...then interleave the last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
        rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
    else:
        # GPT-NeoX style: rotate (first half, second half) and concatenate back.
        x1_1, x1_2 = x1.chunk(2, dim=-1)  # (..., rotary_dim/2)
        rot_x = torch.cat([x1_1 * cos - x1_2 * sin, x1_1 * sin + x1_2 * cos], dim=-1)
    out = torch.cat([rot_x, x2], dim=-1)
    return out
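

# The functions in this file consume precomputed cos/sin tables of shape
# (seqlen, rotary_dim / 2). For reference, below is a minimal sketch (assuming the
# usual RoPE construction with inverse-frequency bands and base 10000) of how such
# tables can be built. The helper name and the `base` parameter are illustrative
# and not part of the original flash-attn API.
def build_rotary_cos_sin_example(seqlen, rotary_dim, base=10000.0, device=None, dtype=torch.float32):
    # One frequency band per rotated pair: inv_freq[i] = base^(-2i / rotary_dim)
    inv_freq = 1.0 / (
        base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    positions = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # (seqlen, rotary_dim / 2)
    return freqs.cos().to(dtype), freqs.sin().to(dtype)
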
def apply_rotary_emb(
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
    # We force the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
    # on all devices: the custom Triton kernel (`ApplyRotaryEmb`) caused a graph
    # break in `torch.compile`, pushing expensive operations to the CPU. With the
    # pure PyTorch version, `torch.compile` can build a single, fully optimized
    # graph, which should resolve the CPU bottleneck and improve GPU utilization.
    # Note that this path supports only the dense (non-varlen) layout with integer
    # seqlen_offsets; see the assertions in `apply_rotary_emb_torch`.
return apply_rotary_emb_torch(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
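

if __name__ == "__main__":
    # Minimal sanity check (illustrative only). Shapes follow the docstring above:
    # x is (batch_size, seqlen, nheads, headdim), cos/sin are (seqlen, rotary_dim / 2).
    # The sizes are arbitrary.
    batch_size, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
    x = torch.randn(batch_size, seqlen, nheads, headdim)
    cos, sin = build_rotary_cos_sin_example(seqlen, rotary_dim)
    out = apply_rotary_emb(x, cos, sin)
    assert out.shape == x.shape
    # Dimensions beyond rotary_dim must pass through unchanged.
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])
    # Because this path is pure PyTorch (no custom kernels), torch.compile
    # (PyTorch 2.x) should be able to capture it in a single graph, which is the
    # motivation stated above. Uncomment to check:
    # compiled_fn = torch.compile(apply_rotary_emb)
    # assert torch.allclose(compiled_fn(x, cos, sin), out, atol=1e-5)
    print("apply_rotary_emb sanity check passed")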