import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_custom_mbz_test import CustomConfig


class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) - efficient implementation
    """

    def __init__(self, d_head, max_seq_len=8192, base=10000.0):
        super().__init__()
        self.d_head = d_head
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """Precompute cos and sin values for positions"""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)

        freqs_cos = torch.cos(freqs)
        freqs_sin = torch.sin(freqs)

        # Duplicate each angle along the last dim so the (seq_len, d_head // 2) tables
        # become (seq_len, d_head) and line up with the interleaved layout of rotate_half.
        self.register_buffer('freqs_cos', freqs_cos.repeat_interleave(2, dim=-1), persistent=False)
        self.register_buffer('freqs_sin', freqs_sin.repeat_interleave(2, dim=-1), persistent=False)

    def rotate_half(self, x):
        """Rotate each interleaved pair of dims: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)"""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)

    def forward(self, q, k, start_pos=0):
        """
        Apply rotary embeddings to query and key tensors
        Args:
            q: (batch_size, n_heads, seq_len, d_head)
            k: (batch_size, n_heads, seq_len, d_head)
            start_pos: starting position for caching scenarios
        Returns:
            q_rot, k_rot with rotary embeddings applied
        """
        seq_len = q.shape[2]

        # (seq_len, d_head) slices broadcast against (batch, n_heads, seq_len, d_head).
        freqs_cos = self.freqs_cos[start_pos:start_pos + seq_len]
        freqs_sin = self.freqs_sin[start_pos:start_pos + seq_len]

        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin

        return q_rot, k_rot
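
# Worked example (illustrative): with the interleaved layout above, each feature pair
# (x_{2i}, x_{2i+1}) at position t is rotated by the angle theta_i = t * inv_freq[i]:
#     x_{2i}'   = x_{2i}   * cos(theta_i) - x_{2i+1} * sin(theta_i)
#     x_{2i+1}' = x_{2i+1} * cos(theta_i) + x_{2i}   * sin(theta_i)
# This is exactly what `q * freqs_cos + rotate_half(q) * freqs_sin` computes, since
# rotate_half places -x_{2i+1} in the even slots and x_{2i} in the odd slots.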


class Attention(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)

        self.rope = RotaryPositionalEncoding(d_head)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        # (batch, seq_len, n_heads * d_head) -> (batch, n_heads, seq_len, d_head)
        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q, k = self.rope(q, k)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        a = a.transpose(1, 2)
        a = a.reshape(batch_size, seq_len, self.n_heads * self.d_head)
        out = self.Wo(a)
        return out
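
# Note (environment assumption): restricting sdpa_kernel to SDPBackend.FLASH_ATTENTION
# makes scaled_dot_product_attention raise if no flash kernel is eligible, e.g. for
# float32 inputs or on CPU. The flash backend needs fp16/bf16 CUDA tensors, which is
# presumably why CustomModelForCausalLM below wraps the forward pass in
# torch.autocast('cuda', dtype=torch.bfloat16).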


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.attn = Attention(d_model, n_heads, d_head)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))

        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.mlp(self.norm2(x)) + x
        return x
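
# Pre-norm residual block: RMSNorm is applied before the attention and MLP sublayers,
# and the raw input is added back after each, as in most modern GPT-style decoders.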


class GPT(nn.Module):
    def __init__(self, d_model, n_heads, d_head, n_vocab, n_layers):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.n_vocab = n_vocab

        self.embed = nn.Embedding(n_vocab, d_model)

        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])

        self.norm = nn.RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.out_head(self.norm(x))
        return x
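
# Shape note: x enters as (batch, seq_len) token ids and leaves as (batch, seq_len, n_vocab)
# logits; the embedding and output head are kept as separate (untied) weight matrices.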


class CustomModelForCausalLM(PreTrainedModel):
    config_class = CustomConfig
    _supports_attention_backend = True

    def __init__(self, config):
        super().__init__(config)
        self.model = GPT(config.d_model, config.n_heads, config.d_head, config.n_vocab, config.n_layers)

    def forward(self, tensor):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            logits = self.model(tensor)
        return CausalLMOutput(logits=logits)

    def get_input_embeddings(self):
        return self.model.embed

    def set_input_embeddings(self, x):
        self.model.embed = x
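
# Usage sketch (illustrative only; requires a CUDA GPU because forward() uses
# torch.autocast('cuda', ...) and the flash attention backend). The CustomConfig field
# names come from the attribute accesses in __init__ above, but the exact constructor
# signature is an assumption:
#
#   config = CustomConfig(d_model=512, n_heads=8, d_head=64, n_vocab=32000, n_layers=8)
#   model = CustomModelForCausalLM(config).cuda()
#   input_ids = torch.randint(0, config.n_vocab, (2, 128), device='cuda')
#   out = model(input_ids)   # CausalLMOutput; out.logits has shape (2, 128, n_vocab)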