| """Shared model architecture for multilingual 3B GPT — must match training exactly.""" |
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB_SIZE = 32000          # SentencePiece vocabulary size
DIM = 3072                  # model / embedding width
DEPTH = 26                  # number of transformer blocks
N_HEADS = 24                # attention heads per block
HEAD_DIM = DIM // N_HEADS   # 128
MAX_SEQ_LEN = 2048          # maximum context length used for RoPE precomputation
ROPE_THETA = 10000.0        # RoPE base frequency
# SwiGLU hidden width: (2/3) * 4 * DIM, rounded up to a multiple of 64 (8192 for DIM = 3072).
HIDDEN_DIM = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64
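
# Rough parameter count implied by these hyperparameters (arithmetic only): the token
# embedding, tied with the output head, is VOCAB_SIZE * DIM ≈ 98.3M; each block carries
# 4 * DIM**2 ≈ 37.7M attention weights plus 3 * DIM * HIDDEN_DIM ≈ 75.5M SwiGLU weights,
# so DEPTH * 113.2M ≈ 2.94B, for a total of roughly 3.04B parameters ("3B").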


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight


def precompute_freqs_cis(dim, max_seq_len, theta=ROPE_THETA):
    """Precompute the complex RoPE table of shape (max_seq_len, dim // 2).

    Entry (t, k) is exp(i * t * theta**(-2 * k / dim)), a unit complex number whose angle grows linearly with position t.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim // 2) rotation angles
    return torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude complex numbers
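
# With HEAD_DIM = 128 and ROPE_THETA = 10000, the per-pair frequencies range from 1.0 down to
# 10000**(-126/128) ≈ 1.2e-4, i.e. rotation wavelengths from about 6 positions up to roughly
# 54k positions, comfortably beyond MAX_SEQ_LEN = 2048.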


def apply_rotary_emb(x, freqs_cis):
    """Apply rotary position embeddings to x of shape (B, H, S, D) by rotating consecutive channel pairs."""
    B, H, S, D = x.shape
    x_complex = torch.view_as_complex(x.float().reshape(B, H, S, D // 2, 2))
    freqs = freqs_cis[:S].unsqueeze(0).unsqueeze(1)  # (1, 1, S, D // 2): broadcast over batch and heads
    x_rot = torch.view_as_real(x_complex * freqs).reshape(B, H, S, D)
    return x_rot.type_as(x)


class FusedAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary position embeddings."""

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)  # fused query/key/value projection
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs_cis, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # (B, S, H, head_dim) -> (B, H, S, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # Scaled dot-product attention with an optional additive (causal) mask.
        scale = math.sqrt(self.head_dim)
        attn = (q @ k.transpose(-2, -1)) / scale
        if mask is not None:
            attn = attn + mask
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return self.out_proj(out)


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU FFN, each wrapped in a residual connection."""

    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = FusedAttention(dim, n_heads)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, hidden_dim)

    def forward(self, x, freqs_cis, mask=None):
        x = x + self.attn(self.attn_norm(x), freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class MultilingualGPT(nn.Module):
    """Decoder-only transformer: RoPE attention, RMSNorm, SwiGLU FFNs, tied input/output embeddings."""

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
        self.layers = nn.ModuleList([
            TransformerBlock(DIM, N_HEADS, HIDDEN_DIM) for _ in range(DEPTH)
        ])
        self.norm = RMSNorm(DIM)
        self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
        # Tie the output projection to the token embedding (weight sharing).
        self.head.weight = self.tok_emb.weight
        # RoPE table is precomputed once and moves with the module (.to / .cuda).
        self.register_buffer('freqs_cis', precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN))

    def forward(self, tokens, targets=None):
        B, S = tokens.shape
        x = self.tok_emb(tokens)
        # Additive causal mask: -inf above the diagonal, broadcast over batch and heads.
        mask = torch.triu(torch.full((S, S), float('-inf'), device=tokens.device), diagonal=1)
        mask = mask.unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, self.freqs_cis, mask)
        x = self.norm(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # Cross-entropy over every position; targets must already be next-token aligned by the caller.
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
        return logits, loss
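

# Minimal greedy-decoding sketch: illustrative convenience code, not part of the original
# training setup. It assumes `tokens` is a (B, S) LongTensor of token ids on the same
# device as the model, and it truncates the growing context to MAX_SEQ_LEN.
@torch.no_grad()
def greedy_generate(model, tokens, max_new_tokens=64, eos_id=None):
    model.eval()
    for _ in range(max_new_tokens):
        ctx = tokens[:, -MAX_SEQ_LEN:]  # keep at most MAX_SEQ_LEN tokens of context
        logits, _ = model(ctx)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (B, 1) greedy pick
        tokens = torch.cat([tokens, next_tok], dim=1)
        if eos_id is not None and (next_tok == eos_id).all():
            break
    return tokens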


def load_model(path, device='cuda'):
    """Load a MultilingualGPT checkpoint, stripping torch.compile / DDP key prefixes."""
    model = MultilingualGPT()
    # weights_only=False can execute arbitrary pickled code; only load trusted checkpoints.
    ckpt = torch.load(path, map_location='cpu', weights_only=False)
    # Accept either a raw state dict or a training checkpoint that wraps one.
    state = ckpt.get('model_state_dict', ckpt)

    cleaned = {}
    for k, v in state.items():
        new_k = k
        for prefix in ['_orig_mod.', 'module.']:
            if new_k.startswith(prefix):
                new_k = new_k[len(prefix):]
        cleaned[new_k] = v

    # With tied embeddings, loading tok_emb.weight also fills head.weight, so the duplicate
    # entry can be dropped; strict=False then tolerates the missing key.
    if 'head.weight' in cleaned and 'tok_emb.weight' in cleaned:
        if torch.equal(cleaned['head.weight'], cleaned['tok_emb.weight']):
            del cleaned['head.weight']
    model.load_state_dict(cleaned, strict=False)
    model = model.to(device).eval()
    return model


def load_tokenizer(path):
    """Load SentencePiece tokenizer."""
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor()
    sp.Load(path)
    return sp
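

# Example usage sketch: the checkpoint and tokenizer paths below are placeholders rather
# than files shipped with this module; point them at your own artifacts. greedy_generate
# above is likewise an illustrative helper, not the decoder used during training.
if __name__ == '__main__':
    model = load_model('checkpoints/multilingual_3b.pt', device='cuda')
    sp = load_tokenizer('tokenizer/multilingual_32k.model')
    prompt_ids = sp.EncodeAsIds('Hello, world!')
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device='cuda')
    eos = sp.eos_id() if sp.eos_id() >= 0 else None  # SentencePiece returns -1 when unset
    out = greedy_generate(model, tokens, max_new_tokens=32, eos_id=eos)
    print(sp.DecodeIds(out[0].tolist()))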