Blackroot's picture
Upload 18 files
6aced58 verified
from flash_attn import flash_attn_func
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .extact import xATGLU
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig
# The four-call flash-attention strategy (two q/k groups x two value halves) comes from here:
# https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py
class DifferentialAttention(nn.Module):
    """Differential multi-head attention with grouped-query (GQA) key/value heads.

    Queries and keys are projected into ``2 * num_heads`` / ``2 * num_kv_heads``
    sub-heads of width ``head_dim = hidden_size // (2 * num_heads)`` and paired
    into two groups (q1/k1 and q2/k2). Values are split into two halves (v1/v2).
    Each (q, k) group attends over both value halves with ``flash_attn_func``
    (four calls in total). The result is ``attn1 - lambda_full * attn2``,
    followed by a non-affine LayerNorm over each head's ``2 * head_dim``
    features and a ``(1 - lambda_init)`` rescale.

    Args:
        config: model config; reads ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings`` and ``rope_theta``.
        layer_num: layer index, used to set
            ``lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_num)``.
    """

    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Query heads per KV head. Kept as an attribute; forward() does not repeat
        # K/V itself but hands num_kv_heads heads straight to flash_attn_func.
        self.n_rep = self.num_heads // self.num_kv_heads
        # Every head is split into two sub-heads, so a sub-head is half as wide.
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Softmax scale used by every flash-attention call: 1 / sqrt(sub-head width).
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(2 * self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Depth-dependent lambda offset: 0.2 at layer_num=0, approaching 0.8 for deep layers.
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        # Learnable vectors that re-parameterize lambda (combined in forward()).
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        # Per-head normalization over the concatenated [v1 | v2] outputs (2 * head_dim).
        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        # RoPE tables, built once. Non-persistent buffers are left out of the
        # state_dict but still move with the module on .to()/.cuda().
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Build rotary-embedding cos/sin lookup tables.

        Args:
            max_position_embeddings: number of positions (rows) to precompute.
            head_dim: rotary width. NOTE(review): assumes an even value; an odd
                value yields tables with ``head_dim + 1`` columns.
            base: RoPE frequency base (theta).
            dtype: dtype of the returned tables.
            device: device on which the tables are created.

        Returns:
            ``(cos, sin)``, each of shape ``(1, max_position_embeddings, head_dim)``.
            The half-width frequency table is duplicated along the last dim.
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def _attend_pair(self, q, k, v1, v2, causal):
        """Attend one (q, k) group over both value halves.

        Returns the two flash-attention outputs concatenated on the last dim,
        i.e. ``[attn(q, k, v1) | attn(q, k, v2)]`` with width ``2 * head_dim``.
        """
        # k/v carry num_kv_heads heads while q carries num_heads; flash_attn_func
        # handles this grouped-query layout itself, so K/V are not repeated here.
        # TODO(@Z): attention dropout is hard-coded to 0.0.
        out_v1 = flash_attn_func(q, k, v1, dropout_p=0.0, softmax_scale=self.scaling, causal=causal)
        out_v2 = flash_attn_func(q, k, v2, dropout_p=0.0, softmax_scale=self.scaling, causal=causal)
        return torch.cat([out_v1, out_v2], dim=-1)

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
    ) -> torch.Tensor:
        """Apply differential attention.

        Args:
            hidden_states: tensor of shape ``(bsz, seq_len, hidden_size)``.
            attention_mask: only compared against ``None``. ``None`` selects
                causal attention. Any other value selects NON-causal attention
                and is otherwise ignored; no mask is passed to flash_attn_func.
            position_ids: RoPE position indices, or ``None`` to use
                ``0..seq_len-1`` for every batch row. Presumably
                ``(bsz, seq_len)`` integer indices into the cached tables.
                TODO: confirm with callers.

        Returns:
            Tensor of shape ``(bsz, seq_len, hidden_size)`` produced by ``o_proj``.
        """
        bsz, seq_len, _ = hidden_states.size()
        if position_ids is None:
            position_ids = repeat(
                torch.arange(seq_len, device=hidden_states.device), 'l -> b l', b=bsz
            )

        # Project, then split q/k into 2 sub-heads per (KV) head.
        q = rearrange(
            self.q_proj(hidden_states), 'b s (h d) -> b s h d', h=2 * self.num_heads, d=self.head_dim
        )
        k = rearrange(
            self.k_proj(hidden_states), 'b s (h d) -> b s h d', h=2 * self.num_kv_heads, d=self.head_dim
        )
        # Values go straight to (kv_head, group) so index 0/1 of the group axis is v1/v2.
        v = rearrange(
            self.v_proj(hidden_states), 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim
        )

        # Gather RoPE rows per position; with (bsz, seq_len) ids this is [1, bsz, seq_len, head_dim].
        cos = self.cos_cached[:, position_ids]
        sin = self.sin_cached[:, position_ids]
        # NOTE(review): q/k are passed as (b, s, h, d) and cos/sin as (1, b, s, d).
        # Confirm that the local `.liger_rope` expects this layout; upstream
        # Liger-Kernel documents q/k as (bsz, n_head, seq_len, head_dim).
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # Pair adjacent sub-heads: element 0 of each pair -> group 1, element 1 -> group 2.
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # NOTE(review): a non-None attention_mask only turns causality off; the mask
        # tensor itself is never applied. Confirm callers rely on this.
        causal = attention_mask is None
        attn1 = self._attend_pair(q1, k1, v1, v2, causal)
        attn2 = self._attend_pair(q2, k2, v1, v2, causal)

        # lambda_full = exp(<q1, k1>) - exp(<q2, k2>) + lambda_init, with the
        # exponentials taken in float32 and cast back to q's dtype.
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn) * (1 - self.lambda_init)
        return self.o_proj(rearrange(attn, "b s h d -> b s (h d)"))