# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch

def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    # Pure PyTorch rotary embedding. Only the fixed-length case (cu_seqlens is None) is
    # supported, and the result is always returned as a new tensor (inplace is ignored).
    assert cu_seqlens is None, "variable-length input (cu_seqlens) is not supported in the torch path"
    assert isinstance(seqlen_offsets, int), "per-sequence offset tensors are not supported in the torch path"
    seqlen = x.shape[1]
    rotary_dim = cos.shape[1] * 2
    x1 = x[..., :rotary_dim]  # the part of headdim that gets rotated
    x2 = x[..., rotary_dim:]  # the remainder is passed through unchanged
    if interleaved:
        # GPT-J style: rotate (even, odd) pairs
        x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)
    else:
        # GPT-NeoX style: rotate (first half, second half)
        x1_1, x1_2 = x1[..., : rotary_dim // 2], x1[..., rotary_dim // 2 :]
    # Select the cos/sin rows for the current positions and reshape for broadcasting:
    # x: [batch, seqlen, nheads, rotary_dim], cos/sin: [seqlen_rotary, rotary_dim/2]
    # -> [1, seqlen, 1, rotary_dim/2]
    cos = cos[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
    sin = sin[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
    rot_x1 = x1_1 * cos - x1_2 * sin
    rot_x2 = x1_1 * sin + x1_2 * cos
    if interleaved:
        # Re-interleave the last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
        rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
    else:
        rot_x = torch.cat([rot_x1, rot_x2], dim=-1)
    return torch.cat([rot_x, x2], dim=-1)
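
# Illustrative helper (not part of the original file): a minimal sketch of how the
# cos/sin tables of shape (seqlen, rotary_dim / 2) that apply_rotary_emb_torch expects
# are commonly precomputed, i.e. inverse frequencies 1 / base^(2i / rotary_dim) times
# the absolute positions. The name `build_rotary_cos_sin` and the default base of
# 10000.0 are assumptions for this sketch, not something defined elsewhere in this repo.
def build_rotary_cos_sin(seqlen, rotary_dim, base=10000.0, device=None, dtype=torch.float32):
    # inv_freq: (rotary_dim / 2,), one frequency per rotated pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=dtype) / rotary_dim))
    # positions: (seqlen,), absolute token positions
    positions = torch.arange(seqlen, device=device, dtype=dtype)
    # angles: (seqlen, rotary_dim / 2)
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
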
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim.
    Apply rotary embedding to the first rotary_dim of x.
    """
    # We are forcing the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
    # for all devices. The custom Triton kernel (`ApplyRotaryEmb`) was causing a graph
    # break in `torch.compile`, pushing expensive operations to the CPU.
    # By using the pure PyTorch version, `torch.compile` can create a single, fully-optimized
    # graph, which should resolve the CPU bottleneck and improve GPU utilization.
    return apply_rotary_emb_torch(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
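
# Illustrative usage (not part of the original file): a small smoke test that relies on
# the hypothetical `build_rotary_cos_sin` helper sketched above. It only checks output
# shapes and that the non-rotated tail of headdim passes through unchanged.
if __name__ == "__main__":
    batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    cos, sin = build_rotary_cos_sin(seqlen, rotary_dim)
    out = apply_rotary_emb(x, cos, sin)
    assert out.shape == x.shape
    # Dimensions beyond rotary_dim must be left untouched.
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])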