|
|
|
|
|
|
|
import torch |
|
import copy |
|
from einops import rearrange |
|
from flash_attn.layers.rotary import RotaryEmbedding |
|
from flash_attn.modules.mha import MHA |
|
|
|
|
|
|
|
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose position indices are divided by a constant.

    Overrides ``_update_cos_sin_cache`` so that, when the cos/sin tables are
    rebuilt, each position index ``t`` is replaced by ``t / scaling_factor``
    before it is multiplied with the inverse frequencies. All other
    constructor arguments are forwarded unchanged to ``RotaryEmbedding``.
    """

    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """Build the embedding and remember the linear scaling factor.

        Args:
            dim: forwarded to ``RotaryEmbedding``.
            scaling_factor: divisor applied to the position indices when the
                cos/sin cache is built. Must be strictly positive.
            base: forwarded to ``RotaryEmbedding``.
            interleaved: forwarded to ``RotaryEmbedding``.
            scale_base: forwarded to ``RotaryEmbedding``.
            pos_idx_in_fp32: forwarded to ``RotaryEmbedding``; also selects
                the dtype of the position indices in
                ``_update_cos_sin_cache``.
            device: forwarded to ``RotaryEmbedding``.

        Raises:
            ValueError: if ``scaling_factor`` is zero, negative or NaN.
        """
        # Dividing positions by zero would silently fill the cos/sin tables
        # with inf/nan, so reject it up front. `not x > 0` also catches NaN.
        if not scaling_factor > 0:
            raise ValueError(
                f"scaling_factor must be > 0, got {scaling_factor!r}"
            )

        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        self._linear_scaling_factor = scaling_factor

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos/sin tables when they are stale.

        The tables are rebuilt when ``seqlen`` exceeds the cached length, when
        no table exists yet, when the cached table's device or dtype differs
        from the requested one, or when the module is in training mode and
        the cached table is an inference tensor. Otherwise nothing happens.

        Args:
            seqlen: number of positions the tables must cover.
            device: device on which position indices are created.
            dtype: dtype the cached tables are cast to.
        """
        cache_is_stale = (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        )
        if not cache_is_stale:
            return

        self._seq_len_cached = seqlen

        if self.pos_idx_in_fp32:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # Use float32 inverse frequencies as well: recompute them when the
            # stored buffer has another dtype (presumably because the module
            # was cast to half precision -- TODO confirm against the parent).
            if self.inv_freq.dtype != torch.float32:
                inv_freq = self._compute_inv_freq(device=device)
            else:
                inv_freq = self.inv_freq
        else:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq

        # Linear scaling: squeeze the position indices by the constant factor.
        t = t / self._linear_scaling_factor

        freqs = torch.outer(t, inv_freq)
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            # Per-position exponent centred on the middle of the sequence;
            # note that it is built from the *unscaled* integer positions.
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

            # The main tables are multiplied by `scale`, the `_k` tables are
            # divided by it.
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
|
|
|
|
def swap_mha_rope(
    mha,
    new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: dict = None,
):
    """Replace ``mha.rotary_emb`` with a newly constructed ``new_rope``.

    The replacement is built from the settings of the existing rotary
    embedding (``dim``, ``base``, ``interleaved``, ``scale_base``,
    ``pos_idx_in_fp32`` and the device of its ``inv_freq``) merged with
    ``kwargs_new_rope``, then cast with ``.to(dtype)`` where ``dtype`` is the
    dtype of ``mha.Wq.weight`` (when ``mha.cross_attn`` is truthy) or of
    ``mha.Wqkv.weight``.

    Args:
        mha: attention module exposing ``rotary_emb``, ``cross_attn`` and
            ``Wq`` / ``Wqkv``; it is modified in place.
        new_rope: class (or callable) used to build the replacement.
        kwargs_new_rope: extra keyword arguments for ``new_rope``. ``None``
            means ``{'scaling_factor': 1.0}``; an explicit empty dict passes
            no extra arguments. Keys that also appear in the settings copied
            from the old embedding override them.

    Returns:
        The same ``mha`` object, now holding the new rotary embedding.
    """
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype

    old_rope = mha.rotary_emb
    kwargs_old_rope = dict(
        dim=old_rope.dim,
        base=old_rope.base,
        interleaved=old_rope.interleaved,
        scale_base=old_rope.scale_base,
        pos_idx_in_fp32=old_rope.pos_idx_in_fp32,
        device=old_rope.inv_freq.device,
    )

    # Only fall back to the default when nothing was supplied: an explicit
    # `{}` must not get `scaling_factor` injected, since `new_rope` may not
    # accept that argument.
    if kwargs_new_rope is None:
        kwargs_new_rope = {'scaling_factor': 1.0}

    # Caller-supplied values win over the ones copied from the old embedding
    # (unpacking both dicts into one call would raise on any shared key).
    rope_kwargs = {**kwargs_old_rope, **kwargs_new_rope}

    # Build the replacement *before* touching `mha`, so a failing constructor
    # leaves the module with its original rotary embedding intact.
    scaled_rope = new_rope(**rope_kwargs).to(dtype)

    del mha.rotary_emb
    mha.rotary_emb = scaled_rope

    assert isinstance(mha.rotary_emb, new_rope)
    return mha