|
from flash_attn import flash_attn_func |
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
|
|
from .extact import xATGLU |
|
from .liger_rope import LigerRopeFunction |
|
from .config import LlamaConfig |
|
|
|
|
|
|
|
|
|
class DifferentialAttention(nn.Module):
    """Differential attention with grouped KV heads, RoPE and FlashAttention.

    Each of the ``num_attention_heads`` heads owns two query/key sub-heads of
    width ``head_dim`` and a value of width ``2 * head_dim``. Two attention
    outputs are computed per head and combined as
    ``attn1 - lambda_full * attn2``. The result is normalised per head with
    ``subln``, scaled by ``1 - lambda_init`` and projected by ``o_proj``.
    """

    def __init__(self, config: "LlamaConfig", layer_num):
        """Create the projections, lambda parameters and RoPE tables.

        Args:
            config: model config; reads ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``max_position_embeddings`` and
                ``rope_theta``.
            layer_num: layer index; sets
                ``lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_num)``.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.n_rep = self.num_heads // self.num_kv_heads
        # Every head is split into two sub-heads, hence the factor of 2.
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Softmax scale used by every attention call in forward().
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(2 * self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # lambda_full = exp(q1.k1) - exp(q2.k2) + lambda_init (see forward);
        # lambda_init increases with layer_num.
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        # Normalises each head's concatenated (2 * head_dim) output.
        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        # Build the RoPE tables once; non-persistent so they stay out of the
        # state dict.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Return ``(cos, sin)`` tables of shape ``(1, max_position_embeddings, head_dim)``.

        The frequency vector is duplicated across the two halves of the last
        dimension. Values are computed in float32 and cast to ``dtype`` only
        when one is given.
        """
        exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()
        if dtype is not None:
            cos, sin = cos.to(dtype), sin.to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def _apply_rope(self, q, k, position_ids):
        """Rotate ``q``/``k`` laid out as ``(batch, seq, heads, head_dim)`` by ``position_ids``."""
        # Per-position tables: (batch, seq, head_dim).
        cos = self.cos_cached[0][position_ids]
        sin = self.sin_cached[0][position_ids]
        # Upstream Liger-Kernel's LigerRopeFunction takes q/k as
        # (batch, heads, seq, head_dim) and cos/sin as (1 or batch, seq, head_dim),
        # so swap the seq/head axes around the call.
        # NOTE(review): confirm the local .liger_rope copy keeps this contract.
        q, k = LigerRopeFunction.apply(q.transpose(1, 2), k.transpose(1, 2), cos, sin, position_ids)
        return q.transpose(1, 2), k.transpose(1, 2)

    def _causal_flash_attn(self, q, k, v):
        """Causal FlashAttention on ``(batch, seq, heads, head_dim)`` tensors."""
        return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=self.scaling, causal=True)

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
    ) -> torch.Tensor:
        """Apply differential attention.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input.
            attention_mask: not applied. ``flash_attn_func`` is called without a
                mask and attention is always causal; a non-``None`` value only
                emits a warning. NOTE(review): padded positions are therefore
                not masked; with causal attention, right padding cannot affect
                earlier real tokens, but left padding would.
            position_ids: ``(seq_len,)`` or ``(batch, seq_len)`` positions, or
                ``None`` for ``0 .. seq_len - 1``.

        Returns:
            ``(batch, seq_len, hidden_size)`` tensor from ``o_proj``.
        """
        bsz, seq_len, _ = hidden_states.size()

        if attention_mask is not None:
            import warnings

            # The mask cannot be passed to flash_attn_func, and a decoder layer
            # must stay causal regardless of whether a mask is supplied.
            warnings.warn(
                "DifferentialAttention ignores attention_mask; attention is always causal."
            )

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
        if position_ids.dim() == 1:
            # Share one position vector across the batch.
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = rearrange(q, 'b s (h d) -> b s h d', h=2 * self.num_heads, d=self.head_dim)
        k = rearrange(k, 'b s (h d) -> b s h d', h=2 * self.num_kv_heads, d=self.head_dim)
        v = rearrange(v, 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim)

        q, k = self._apply_rope(q, k, position_ids)

        # Pair adjacent sub-heads: index 0 feeds the first map, index 1 the second.
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)

        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # The 2 * head_dim value is attended in two head_dim halves and
        # concatenated back (presumably because flash_attn_func expects equal
        # q/k/v head dims).
        attn1 = torch.cat([self._causal_flash_attn(q1, k1, v1), self._causal_flash_attn(q1, k1, v2)], dim=-1)
        attn2 = torch.cat([self._causal_flash_attn(q2, k2, v1), self._causal_flash_attn(q2, k2, v2)], dim=-1)

        # Reduce in float32 before casting back to the activation dtype.
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2

        attn = self.subln(attn) * (1 - self.lambda_init)

        attn_output = rearrange(attn, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)