from typing import Tuple

import torch

from open_lm.positional_embedding.rotary import apply_rotary_pos_emb, RotaryEmbedding


class HeadRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings used in the first version of OpenLM.
    It is kept only for compatibility; RotaryEmbedding should be used instead.
    """

    def __init__(self, dim_model: int, seq_len: int, *_, **__):
        super().__init__(dim_model, seq_len)
        self._has_warned = False

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        # Size the cos/sin tables from k.shape[2], preserving the original (v1) behavior.
        self._update_cos_sin_tables(k.shape[2], device=k.device, dtype=k.dtype)

        if not self._has_warned and offset != 0:
            print("Warning: HeadRotaryEmbedding does not support offset; it will not be applied.")
            self._has_warned = True

        # Dims 1 and 2 are swapped before and after the rotation, as in the first OpenLM release.
        out_q = apply_rotary_pos_emb(q.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2)
        out_k = apply_rotary_pos_emb(k.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2)
        return out_q, out_k


class HeadRotaryWithCast(HeadRotaryEmbedding):
    def forward(self, q, k, v, offset: int = 0):
        q, k = super().forward(q, k, offset)
        # Cast the rotated q and k back to v's dtype so all three tensors match downstream.
        return q.to(v.dtype), k.to(v.dtype), v
|
|
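# Usage sketch (not part of the original module): illustrates the expected call pattern for
# HeadRotaryWithCast. The (batch, seq_len, n_heads, head_dim) layout and the concrete sizes
# below are illustrative assumptions; adapt them to how your attention code lays out q, k, v.
if __name__ == "__main__":
    batch, seq_len, n_heads, head_dim = 2, 16, 8, 64

    # dim_model is the per-head dimension that gets rotated.
    pos_embed = HeadRotaryWithCast(dim_model=head_dim, seq_len=seq_len)

    q = torch.randn(batch, seq_len, n_heads, head_dim)
    k = torch.randn(batch, seq_len, n_heads, head_dim)
    v = torch.randn(batch, seq_len, n_heads, head_dim)

    # Only q and k are rotated; v passes through, and q/k are cast to v's dtype.
    q_rot, k_rot, v_out = pos_embed(q, k, v)
    print(q_rot.shape, k_rot.shape, v_out.shape)  # shapes are unchanged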