# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch


def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    # Pure PyTorch implementation. Fixed-length sequences only: variable-length
    # (cu_seqlens) inputs are not supported.
    assert cu_seqlens is None, "apply_rotary_emb_torch does not support cu_seqlens"
    seqlen = x.shape[1]
    rotary_dim = cos.shape[-1] * 2
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim:]
    # Select the rows of cos/sin matching the positions of x, honoring seqlen_offsets
    # (used e.g. during inference with a KV cache), then reshape for broadcasting:
    # x:       [batch, seqlen, nheads, rotary_dim]
    # cos/sin: [seqlen_rotary, rotary_dim/2] -> [1 or batch, seqlen, 1, rotary_dim/2]
    if isinstance(seqlen_offsets, torch.Tensor):
        # Per-sequence integer offsets: gather a row of cos/sin for each position.
        positions = seqlen_offsets[:, None] + torch.arange(seqlen, device=seqlen_offsets.device)
        cos = cos[positions].unsqueeze(2)
        sin = sin[positions].unsqueeze(2)
    else:
        # A single integer offset shared by every sequence in the batch.
        cos = cos[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
        sin = sin[seqlen_offsets : seqlen_offsets + seqlen].unsqueeze(0).unsqueeze(2)
    if interleaved:
        # GPT-J style: rotate (even, odd) pairs of dimensions.
        x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)
    else:
        # GPT-NeoX style: rotate the 1st and 2nd halves of the rotary dimensions.
        x1_1, x1_2 = x1.chunk(2, dim=-1)  # (..., rotary_dim/2)
    rot_x1 = x1_1 * cos - x1_2 * sin
    rot_x2 = x1_1 * sin + x1_2 * cos
    if interleaved:
        # Re-interleave the last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
        rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
    else:
        rot_x = torch.cat([rot_x1, rot_x2], dim=-1)
    if inplace:
        # Write the rotated values back into x; the non-rotary tail is already there.
        x[..., :rotary_dim] = rot_x
        return x
    return torch.cat([rot_x, x2], dim=-1)


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    # We force the pure PyTorch implementation (`apply_rotary_emb_torch`) on all devices.
    # The custom Triton kernel (`ApplyRotaryEmb`) caused a graph break in `torch.compile`,
    # pushing expensive operations onto the CPU. With the pure PyTorch version,
    # `torch.compile` can trace a single, fully optimized graph, which should remove the
    # CPU bottleneck and improve GPU utilization.
    return apply_rotary_emb_torch(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
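

# Usage sketch (illustrative, not part of the library): a minimal self-check that
# assumes a standard 10000-base inverse-frequency rotary table; the shapes below
# are arbitrary examples chosen for demonstration.
if __name__ == "__main__":
    batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
    x = torch.randn(batch, seqlen, nheads, headdim)
    # Build cos/sin tables of shape (seqlen, rotary_dim / 2), as expected above.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    out = apply_rotary_emb(x, freqs.cos(), freqs.sin())
    assert out.shape == x.shape
    # Dimensions beyond rotary_dim pass through unchanged.
    assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])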