# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Armin Thomas, Eric Nguyen

import torch
import copy
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA


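# Simple wrapper around flash-attn's RotaryEmbedding that divides token positions
# by a constant `scaling_factor` before computing the rotation angles (linear RoPE scaling):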
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device
        )
        self._linear_scaling_factor = scaling_factor

    # adapted from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has increased,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # linear scaling:
                t = t / self._linear_scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # linear scaling:
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
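

# Reference sketch (hypothetical helper, not part of flash-attn or of the class above):
# it recomputes the linearly scaled cos/sin tables in plain PyTorch. It is useful for
# checking what `_update_cos_sin_cache` produces in the fp32, no-xPos case
# (scale_base=None). It assumes flash-attn's default inverse frequencies,
# inv_freq[i] = 1 / base**(2i / dim). Dividing positions by `scaling_factor` means
# position p with factor s receives the same rotation angles as position p / s
# without scaling.
import torch  # already imported above; repeated so this sketch is self-contained


def reference_linear_scaled_cos_sin(
    seqlen: int,
    dim: int,
    base: float = 10000.0,
    scaling_factor: float = 1.0,
):
    # one inverse frequency per rotated channel pair, in fp32: shape (dim // 2,)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # linearly scaled positions: shape (seqlen,)
    t = torch.arange(seqlen, dtype=torch.float32) / scaling_factor
    # rotation angle for every (position, frequency) pair: shape (seqlen, dim // 2)
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

# Usage (illustrative): with scaling_factor=2.0, row 2 of the scaled tables equals
# row 1 of the unscaled tables, because 2 / 2.0 == 1. How the scaling factor is chosen
# for a given context extension is not stated in this file; a common convention
# (an assumption here) is target_length / trained_length.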


# Swap out the RoPE of an existing flash-attn MHA module, keeping its RoPE settings:
def swap_mha_rope(
    mha,
    new_rope: type = LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: dict=None
):
    # determine mha dtype and device (a cross-attention MHA has a separate Wq projection
    # instead of the fused Wqkv projection):
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
    # determine RoPE settings:
    kwargs_old_rope = dict(
        dim = mha.rotary_emb.dim,
        base = mha.rotary_emb.base,
        interleaved = mha.rotary_emb.interleaved,
        scale_base = mha.rotary_emb.scale_base,
        pos_idx_in_fp32 = mha.rotary_emb.pos_idx_in_fp32,
        device = mha.rotary_emb.inv_freq.device
    )
    # delete old RoPE:
    del mha.rotary_emb
    # create new RoPE:
    kwargs_new_rope = kwargs_new_rope or {'scaling_factor': 1.0}
    scaled_rope = new_rope(
        **kwargs_new_rope,
        **kwargs_old_rope
    ).to(dtype)
    # attach new RoPE to mha:
    mha.rotary_emb = scaled_rope
    # make sure the new RoPE is correctly registered:
    assert isinstance(mha.rotary_emb, new_rope)
    return mha