|
import torch |
|
import math |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import einops |
|
from rotary_embedding_torch import RotaryEmbedding |
|
|
|
class TransformerEncoder(torch.nn.Module):
    """Single pre-norm Transformer encoder layer.

    The layer applies, in order:

    1. LayerNorm -> multi-head self-attention -> residual add
    2. LayerNorm -> two-layer feed-forward network -> residual add

    Args:
        hidden_embed_size: Size of the feature dimension of the input.
        n_attn_heads: Number of attention heads; must divide
            ``hidden_embed_size``.
        attn_dropout: Dropout probability forwarded to the attention module.
        layer_norm_eps: Epsilon used by both LayerNorm modules.
        a_fn: Name of the feed-forward activation, resolved through
            ``get_activation_fn`` (``"gelu"`` or ``"swiglu"``).
    """

    def __init__(
        self,
        hidden_embed_size,
        n_attn_heads,
        attn_dropout: float = 0.0,
        layer_norm_eps: float = 1e-05,
        a_fn: str = "gelu",
    ):
        super().__init__()

        assert hidden_embed_size % n_attn_heads == 0, \
            "Embedding dimension must be divisible by the number of heads."

        self.multihead_attention = MultiHeadAttention(
            embed_dim=hidden_embed_size,
            num_heads=n_attn_heads,
            attention_dropout_prob=attn_dropout,
        )

        # `scale` widens the first projection for gated activations: SwiGLU
        # splits its input in half, so the first linear layer must emit twice
        # as many features for the second one to still receive 4 * hidden.
        activation_fn, scale = get_activation_fn(a_fn)
        inner_size = hidden_embed_size * 4

        self.intermediate_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_embed_size, inner_size * scale),
            activation_fn(),
            torch.nn.Linear(inner_size, hidden_embed_size),
        )

        self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
        self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)

    def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
        """Run the encoder layer.

        Args:
            hidden_embed: Input tensor whose last dimension is
                ``hidden_embed_size``.
            attn_mask: Optional mask forwarded unchanged to the attention
                module.
            return_attn_weights: Forwarded to the attention module; controls
                whether the second return value is populated.

        Returns:
            A tuple ``(hidden_embed, attn_weights)`` where ``attn_weights`` is
            whatever the attention module returned as its second value.
        """
        # Attention sub-block. LayerNorm is out-of-place, so the input can be
        # normalised directly without a defensive copy.
        attn_out, attn_weights = self.multihead_attention(
            self.pre_attn_layer_norm(hidden_embed),
            attn_mask=attn_mask,
            return_attn_weights=return_attn_weights,
        )
        hidden_embed = hidden_embed + attn_out

        # Feed-forward sub-block.
        ffn_out = self.intermediate_layer(self.final_layer_norm(hidden_embed))
        hidden_embed = hidden_embed + ffn_out

        return hidden_embed, attn_weights
|
|
|
class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries, keys and values are all projected from the same input tensor.
    Rotary embeddings are applied to the per-head queries and keys, then
    scaled dot-product attention is computed and the heads are merged through
    an output projection.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        attention_dropout_prob: Dropout probability applied to the attention
            probabilities.
        bias: Whether the four linear projections have a bias term.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attention_dropout_prob: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()

        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        # Kept as a public attribute for backward compatibility. The
        # 1/sqrt(head_dim) factor itself is applied inside `attention`.
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Re-initialise the projection weights with Xavier-uniform values.

        The q/k/v projections use a gain of ``1 / sqrt(2)``; the output
        projection uses the default gain and has its bias (if any) zeroed.
        """
        qkv_gain = 1 / math.sqrt(2)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=qkv_gain)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=qkv_gain)
        nn.init.xavier_uniform_(self.q_proj.weight, gain=qkv_gain)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def attention(self, q, k, v, attn_mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: Query tensor; its last two dimensions are matrix-multiplied
                with the transposed last two dimensions of ``k``.
            k: Key tensor.
            v: Value tensor.
            attn_mask: Optional 2-D mask of shape ``(batch, seq_len)``. It is
                expanded to ``(batch, 1, 1, seq_len)`` and every position
                where it is true is filled with ``-inf`` before the softmax.
                ``masked_fill`` expects a boolean mask.

        Returns:
            A tuple ``(attn, attn_weights)``: the attended values (computed
            from the dropped-out probabilities) and the softmax probabilities
            before dropout.

        Note:
            A row whose keys are all masked yields NaN after the softmax.
        """
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        # This is the single place where the 1/sqrt(head_dim) factor is
        # applied; callers must pass unscaled queries.
        attn_weights = attn_weights / math.sqrt(self.head_dim)

        if attn_mask is not None:
            # Insert singleton head and query axes so the mask broadcasts
            # across all heads and all query positions.
            attn_mask = einops.rearrange(
                attn_mask,
                'b_size (h1 h2 seq_len) -> b_size h1 h2 seq_len',
                h1=1, h2=1
            )
            attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)

        attn = self.attention_dropout(attn_weights)
        attn = torch.matmul(attn, v)
        return attn, attn_weights

    def _split_heads(self, tensor, batch_size, seq_len):
        """Reshape ``(batch, seq, embed)`` to ``(batch, heads, seq, head_dim)``."""
        return tensor.contiguous().view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x, attn_mask=None, return_attn_weights: bool = False):
        """Apply self-attention to ``x``.

        Args:
            x: 3-D input of shape ``(batch, seq_len, embed_dim)``.
            attn_mask: Optional mask passed to :meth:`attention`.
            return_attn_weights: When true, the attention probabilities are
                returned as the second element; otherwise ``None`` is.

        Returns:
            A tuple ``(attn, attn_weights_or_None)`` where ``attn`` has shape
            ``(batch, seq_len, embed_dim)``.
        """
        batch_size, seq_len, embed_dim = x.size()

        # The queries are intentionally NOT pre-multiplied by `self.scaling`:
        # `attention` already divides the logits by sqrt(head_dim), and
        # scaling in both places would shrink the logits by 1/head_dim
        # instead of 1/sqrt(head_dim).
        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        attn, attn_weights = self.attention(q, k, v, attn_mask=attn_mask)

        # Merge the heads back into the feature dimension.
        attn = attn.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        attn = self.out_proj(attn)

        if return_attn_weights:
            return attn, attn_weights
        return attn, None
|
|
|
class SwiGLU(torch.nn.Module):
    """SiLU-gated linear unit.

    The last dimension of the input is split into two equal halves: the
    first half carries the values and the second half the gate. The output
    is ``silu(gate) * values`` and therefore has half as many features as
    the input.
    """

    def forward(self, x):
        """Return ``silu(second_half) * first_half`` along the last axis."""
        values, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = F.silu(gate)
        return activated_gate * values
|
|
|
def get_activation_fn(a_fn):
    """Resolve an activation name to its module class and width multiplier.

    Args:
        a_fn: Name of the activation, either ``"gelu"`` or ``"swiglu"``.

    Returns:
        A tuple ``(activation_class, scale)``. ``activation_class`` is an
        uninstantiated module class. ``scale`` is the factor by which the
        layer feeding the activation must widen its output: ``1`` for GELU,
        ``2`` for SwiGLU, which halves the feature dimension.

    Raises:
        ValueError: If ``a_fn`` is not a supported activation name.
    """
    if a_fn == "gelu":
        return torch.nn.GELU, 1

    if a_fn == "swiglu":
        return SwiGLU, 2

    # Fail loudly here instead of returning None, which would only surface
    # later as an opaque "cannot unpack NoneType" error in the caller.
    raise ValueError(
        f"Unsupported activation function {a_fn!r}; expected 'gelu' or 'swiglu'."
    )
|
|