import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_custom_mbz_test import CustomConfig


class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) with cos/sin tables precomputed up to max_seq_len.
    """
def __init__(self, d_head, max_seq_len=8192, base=10000.0):
super().__init__()
self.d_head = d_head
self.max_seq_len = max_seq_len
self.base = base
# Precompute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
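        # inv_freq[i] = base^(-2i / d_head): a geometric ladder of rotation frequencies (standard RoPE schedule)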
self.register_buffer('inv_freq', inv_freq, persistent=False)
# Precompute cos and sin for maximum sequence length
self._precompute_freqs(max_seq_len)
def _precompute_freqs(self, seq_len):
"""Precompute cos and sin values for positions"""
t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq) # (seq_len, d_head/2)
# Create cos and sin embeddings
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
# Interleave to match the dimension (seq_len, d_head)
self.register_buffer('freqs_cos', freqs_cos.repeat_interleave(2, dim=-1), persistent=False)
self.register_buffer('freqs_sin', freqs_sin.repeat_interleave(2, dim=-1), persistent=False)
    def rotate_half(self, x):
        """Rotate each consecutive pair of dims: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)"""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)
def forward(self, q, k, start_pos=0):
"""
Apply rotary embeddings to query and key tensors
Args:
q: (batch_size, n_heads, seq_len, d_head)
k: (batch_size, n_heads, seq_len, d_head)
start_pos: starting position for caching scenarios
Returns:
q_rot, k_rot with rotary embeddings applied
"""
        seq_len = q.shape[2]
        # Slice the precomputed tables for this window and match the query's dtype/device
        # so the rotation stays in the input precision (e.g. bf16 under autocast)
        freqs_cos = self.freqs_cos[start_pos:start_pos + seq_len].to(dtype=q.dtype, device=q.device)
        freqs_sin = self.freqs_sin[start_pos:start_pos + seq_len].to(dtype=q.dtype, device=q.device)
        # Apply the rotation: x -> x * cos + rotate_half(x) * sin
        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin
return q_rot, k_rot


class Attention(nn.Module):
def __init__(self, d_model, n_heads, d_head):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_head
self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)
# Initialize RoPE
self.rope = RotaryPositionalEncoding(d_head)
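        # Note: the RoPE cache covers max_seq_len positions (8192 by default); longer
        # sequences would need a larger cache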
def forward(self, x):
# x is shape batch_size, seq_len, d_model
batch_size, seq_len, d_model = x.shape
q = self.Wq(x) # q is shape batch_size, seq_len, n_heads * d_head
k = self.Wk(x)
v = self.Wv(x)
# reshape to batch_size, n_heads, seq_len, d_head
q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
q, k = self.rope(q, k)
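        # Rotary embeddings are applied after splitting heads; start_pos defaults to 0,
        # so this attention module does not maintain a KV cache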
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # force the FlashAttention SDPA backend (requires fp16/bf16 inputs on CUDA)
            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)  # (batch_size, n_heads, seq_len, d_head)
a = a.transpose(1,2) # change a to (batch_size, seq_len, n_heads, d_head)
a = a.reshape(batch_size, seq_len, self.n_heads * self.d_head)
out = self.Wo(a) # out is (batch_size, seq_len, d_model)
return out


class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_head):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_head
self.attn = Attention(d_model, n_heads, d_head)
self.mlp = nn.Sequential(nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model))
self.norm1 = nn.RMSNorm(d_model)
self.norm2 = nn.RMSNorm(d_model)
def forward(self, x):
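        # Pre-norm residual block: normalize, transform, then add back the input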
x = self.attn(self.norm1(x)) + x
x = self.mlp(self.norm2(x)) + x
return x


class GPT(nn.Module):
def __init__(self, d_model, n_heads, d_head, n_vocab, n_layers):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_head
self.n_vocab = n_vocab
self.embed = nn.Embedding(n_vocab, d_model)
self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])
self.norm = nn.RMSNorm(d_model)
self.out_head = nn.Linear(d_model, n_vocab)
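        # The output head is a separate (biased) projection; its weights are not tied
        # to the input embedding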
def forward(self, x):
x = self.embed(x)
for block in self.blocks:
x = block(x)
x = self.out_head(self.norm(x))
return x


class CustomModelForCausalLM(PreTrainedModel):
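    # config_class ties this model to CustomConfig so Transformers can rebuild the
    # matching configuration when loading from a checkpoint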
config_class = CustomConfig
_supports_attention_backend = True
def __init__(self, config):
super().__init__(config)
self.model = GPT(config.d_model, config.n_heads, config.d_head, config.n_vocab, config.n_layers)
def forward(self, tensor):
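        # Note: input ids are taken as a single positional tensor, and autocast is
        # hard-coded to CUDA bfloat16, so this forward pass is intended for GPU execution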
with torch.autocast('cuda', dtype=torch.bfloat16):
logits = self.model(tensor)
return CausalLMOutput(logits=logits)
def get_input_embeddings(self):
return self.model.embed
def set_input_embeddings(self, x):
        self.model.embed = x
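

# Minimal smoke test (an illustrative sketch, not part of the original API).
# Assumptions: CustomConfig accepts the field names this file reads from it
# (d_model, n_heads, d_head, n_vocab, n_layers) as keyword arguments, and a CUDA
# device is available, since Attention.forward forces the FlashAttention backend
# and CustomModelForCausalLM.forward autocasts to bfloat16 on 'cuda'. Because of
# the relative import above, run it as a module (python -m <package>.<this_file>).
if __name__ == "__main__":
    config = CustomConfig(d_model=256, n_heads=4, d_head=64, n_vocab=1000, n_layers=2)
    model = CustomModelForCausalLM(config).cuda()
    input_ids = torch.randint(0, config.n_vocab, (2, 16), device="cuda")
    with torch.no_grad():
        out = model(input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])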