File size: 2,681 Bytes
adf0368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch


def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    """Apply rotary position embedding to the first ``rotary_dim`` features of ``x``.

    Pure-PyTorch implementation of the contract documented on ``apply_rotary_emb``.

    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos, sin: (seqlen_rotary, rotary_dim / 2). Row ``p`` holds the values for
            position ``p``; seqlen_rotary must cover every position used.
        interleaved: if True, rotate pairs of even and odd features (GPT-J style);
            otherwise rotate the 1st half against the 2nd half (GPT-NeoX style).
        inplace: if True, write the result into ``x`` and return ``x`` itself.
        seqlen_offsets: int, or (batch_size,) tensor. Position of the first token
            of each sequence (e.g. the KV-cache length during decoding).
        cu_seqlens: (batch_size + 1,) cumulative sequence lengths for packed
            variable-length input, or None.
        max_seqlen: accepted for signature compatibility; not used here.
    Return:
        Tensor with the same shape and dtype as ``x``; features past
        ``rotary_dim`` are passed through unchanged.
    Raises:
        ValueError: if rotary_dim > headdim, or an int ``seqlen_offsets`` pushes
            positions past the end of ``cos``/``sin``.
    """
    rotary_dim = cos.shape[-1] * 2
    if rotary_dim > x.shape[-1]:
        raise ValueError(
            f"rotary_dim ({rotary_dim}) must be <= headdim ({x.shape[-1]})"
        )
    cos, sin = _rotary_cos_sin_for_tokens(cos, sin, x, seqlen_offsets, cu_seqlens)

    x_rot = x[..., :rotary_dim]
    if interleaved:
        # GPT-J: features (0, 1), (2, 3), ... form the rotated pairs.
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
    else:
        # GPT-NeoX: feature i pairs with feature i + rotary_dim / 2.
        x1, x2 = x_rot.chunk(2, dim=-1)

    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    if interleaved:
        # (..., rotary_dim/2, 2) -> (..., rotary_dim), restoring the even/odd layout.
        rotated = torch.stack((out1, out2), dim=-1).flatten(-2)
    else:
        rotated = torch.cat((out1, out2), dim=-1)
    # If cos/sin differ in dtype from x the arithmetic above promotes; cast back
    # so the output dtype always matches x (no-op when dtypes already agree).
    rotated = rotated.to(x.dtype)

    if inplace:
        # `rotated` is a freshly allocated tensor, so overwriting x cannot corrupt it.
        x[..., :rotary_dim].copy_(rotated)
        return x
    return torch.cat((rotated, x[..., rotary_dim:]), dim=-1)


def _rotary_cos_sin_for_tokens(cos, sin, x, seqlen_offsets, cu_seqlens):
    """Select the cos/sin row for every token of ``x``, shaped to broadcast over heads."""
    device = cos.device
    if cu_seqlens is None:
        # x: (batch_size, seqlen, nheads, headdim)
        seqlen = x.shape[1]
        if isinstance(seqlen_offsets, torch.Tensor):
            # Per-sequence offsets -> positions of shape (batch_size, seqlen).
            offsets = seqlen_offsets.to(device=device, dtype=torch.long).reshape(-1, 1)
            positions = offsets + torch.arange(seqlen, device=device)
            # (batch, seqlen, rotary_dim/2) -> (batch, seqlen, 1, rotary_dim/2)
            return cos[positions].unsqueeze(2), sin[positions].unsqueeze(2)
        end = seqlen_offsets + seqlen
        if end > cos.shape[0]:
            raise ValueError(
                f"seqlen_offsets ({seqlen_offsets}) + seqlen ({seqlen}) exceeds "
                f"the {cos.shape[0]} positions available in cos/sin"
            )
        # (seqlen, rotary_dim/2) -> (1, seqlen, 1, rotary_dim/2)
        return (
            cos[seqlen_offsets:end][None, :, None, :],
            sin[seqlen_offsets:end][None, :, None, :],
        )

    # Packed variable-length input: x is (total_seqlen, nheads, headdim).
    cu = cu_seqlens.to(device=device, dtype=torch.long)
    token = torch.arange(x.shape[0], device=device)
    # Index of the sequence each packed token belongs to.
    seq_idx = torch.searchsorted(cu[1:], token, right=True)
    # Position of each token within its own sequence.
    positions = token - cu[seq_idx]
    if isinstance(seqlen_offsets, torch.Tensor):
        positions = positions + seqlen_offsets.to(device=device, dtype=torch.long)[seq_idx]
    else:
        positions = positions + seqlen_offsets
    # (total_seqlen, rotary_dim/2) -> (total_seqlen, 1, rotary_dim/2)
    return cos[positions].unsqueeze(1), sin[positions].unsqueeze(1)


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Rotary-embedding entry point; delegates every call to ``apply_rotary_emb_torch``.

    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style)
            instead of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this
            amount. Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)

    rotary_dim must be <= headdim; the embedding is applied to the first
    rotary_dim features of x.
    """
    # The Triton kernel (`ApplyRotaryEmb`) is deliberately bypassed on every
    # device: it reportedly caused a graph break under `torch.compile`, which
    # pushed work onto the CPU. Routing through the pure-PyTorch implementation
    # lets the compiler capture this call inside a single graph.
    return apply_rotary_emb_torch(
        x,
        cos,
        sin,
        interleaved=interleaved,
        inplace=inplace,
        seqlen_offsets=seqlen_offsets,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )