import torch
import torch.nn as nn
import torch.nn.functional as F

# ==========================================
# MODEL CONFIG (Matching your 1.2M Llama)
# ==========================================
n_embd = 128
n_head = 4
n_layer = 6
block_size = 256
dropout = 0.2

# Tiny Shakespeare vocab (65 characters)
chars = ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
vocab_size = len(chars)  # 65
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s if c in stoi]  # drops out-of-vocab characters
decode = lambda l: ''.join([itos[i] for i in l])
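# Round-trip example (illustrative): encode("Hi") == [20, 47] and
# decode([20, 47]) == "Hi"; encode silently drops any character
# outside the 65-character vocab.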

# ==========================================
# HELPERS (RoPE & RMSNorm)
# ==========================================
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Per-pair rotation frequencies: theta^(-2i/dim) for i = 0..dim/2-1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Unit-magnitude complex exponentials e^{i * t * freq}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
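# For this config, precompute_freqs_cis(n_embd // n_head, block_size)
# returns a complex64 tensor of shape (256, 16): one unit-magnitude
# rotation per position and per channel pair of each 32-dim head.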

def apply_rotary_emb(xq, xk, freqs_cis):
    # View the last dim as complex pairs: (B, T, n_head, head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast the (T, head_dim // 2) table over batch and heads
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
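# Shape sanity check (illustrative, not part of the original script):
#   q = torch.randn(1, block_size, n_head, n_embd // n_head)
#   fc = precompute_freqs_cis(n_embd // n_head, block_size)
#   apply_rotary_emb(q, q, fc)[0].shape  # torch.Size([1, 256, 4, 32])
# Rotating channel pairs by unit complex numbers preserves norms, so
# attention logits keep the same scale at every position.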

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), scaled by a learned per-channel gain
        x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x_normed

# ==========================================
# CORE LAYERS
# ==========================================
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # 8/3 * dim keeps the parameter count near a classic 4 * dim MLP
        # despite the third matrix; round up to a multiple of 4 (128 -> 344)
        hidden_dim = int(8 / 3 * dim)
        hidden_dim = 4 * ((hidden_dim + 3) // 4)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # silu(W1 x) gated elementwise by W3 x, projected back by W2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class CausalSelfAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(n_embd, dim=2)
        # (B, T, n_head, head_dim): RoPE is applied before moving heads forward
        q = q.view(B, T, n_head, C // n_head)
        k = k.view(B, T, n_head, C // n_head)
        v = v.view(B, T, n_head, C // n_head)
        q, k = apply_rotary_emb(q, k, freqs_cis)
        # (B, n_head, T, head_dim) for attention
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True,
                                           dropout_p=dropout if self.training else 0.0)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(n_embd)
        self.attn = CausalSelfAttention()
        self.ln_2 = RMSNorm(n_embd)
        self.ffwd = SwiGLU(n_embd)

    def forward(self, x, freqs_cis):
        # Pre-norm residual connections
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.ffwd(self.ln_2(x))
        return x

# ==========================================
# FINAL MODEL CLASS
# ==========================================
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block() for _ in range(n_layer)])
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.token_embedding_table.weight = self.lm_head.weight  # weight tying
        freqs_cis = precompute_freqs_cis(n_embd // n_head, block_size)
        self.register_buffer("freqs_cis", freqs_cis)
        # Safe to run after tying: both attributes point at the same tensor
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.token_embedding_table(idx)  # (B, T, n_embd)
        freqs_cis = self.freqs_cis[:T]
        for block in self.blocks:
            x = block(x, freqs_cis)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # Original signature accepted targets but never used them;
        # compute the standard next-token cross-entropy here
        loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # crop to the context window
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
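
# ==========================================
# USAGE SKETCH (illustrative addition, not part of the original Space):
# builds the ~1.2M-parameter model and samples from untrained weights
# ==========================================
if __name__ == "__main__":
    model = LanguageModel()
    model.eval()
    prompt = torch.tensor([encode("ROMEO:")], dtype=torch.long)
    out = model.generate(prompt, max_new_tokens=100)
    print(decode(out[0].tolist()))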