# tiny-llm-cli-sft / model.py
# Uploaded by jonmabe to the Hugging Face Hub with huggingface_hub (commit e70d699, verified).
"""
Tiny Transformer with modern components:
- RoPE (Rotary Position Embeddings)
- RMSNorm
- SwiGLU activation
- Weight tying
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by its RMS (no mean centering,
    no bias) and then scaled by a learned per-feature weight initialized to 1.

    Args:
        dim: Size of the last (feature) dimension of the input.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the statistic in at least float32. In fp16, x**2 overflows to
        # inf once |x| exceeds ~256, rsqrt(inf) is 0, and the whole output would
        # collapse to zero. For float32/float64 inputs this is a no-op, so
        # results are bit-identical to computing in the input dtype.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        inv_rms = torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back to the input dtype before applying the learned scale, so the
        # result dtype follows the same promotion rules as before.
        return (x_c * inv_rms).to(x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Build the cos/sin tables used by rotary position embeddings (RoPE).

    Args:
        dim: Feature size the rotation is applied to. Must be even, because
            features are rotated in pairs (the frequency table has dim // 2
            entries and is duplicated to length dim).
        max_seq_len: Stored on the module; not used by this class to bound
            or cache anything.
        base: Base of the geometric progression of rotation frequencies.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Persistent buffer (part of the state_dict); kept that way so existing
        # checkpoints keep loading.
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len: int):
        """Return ``(cos, sin)``, each of shape ``(seq_len, dim)``.

        ``x`` is only used for its device. Angles are computed in float32 even
        when the module has been cast to half precision. bfloat16 cannot
        represent integer positions above 256 exactly, so building the
        positions in that dtype would corrupt the rotation for long sequences.
        The tables are cast back to the dtype of ``inv_freq``, so callers get
        the same dtype as before.
        """
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        out_dtype = self.inv_freq.dtype
        return emb.cos().to(out_dtype), emb.sin().to(out_dtype)
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the new first half.

    For a last axis laid out as ``[a, b]`` the result is ``[-b, a]``. This is
    the term RoPE multiplies by the sine table. The split uses
    ``torch.chunk``, so an odd-sized axis puts the extra element in ``a``.
    """
    left, right = torch.chunk(x, 2, dim=-1)
    rotated = (torch.neg(right), left)
    return torch.cat(rotated, dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by the RoPE cos/sin tables.

    Two leading singleton axes are added to ``cos``/``sin`` so they broadcast
    against ``q``/``k``. As called from ``Attention.forward``, q/k are
    (batch, heads, seq_len, head_dim) and cos/sin are (seq_len, head_dim).

    Returns:
        Tuple ``(q_rotated, k_rotated)`` computed as
        ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos[None, None]  # [1, 1, seq_len, dim]
    sin_b = sin[None, None]

    def _rope(t):
        return t * cos_b + rotate_half(t) * sin_b

    return _rope(q), _rope(k)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of the inner (gated) projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Creation order (w1, w2, w3) is preserved so random initialization and
        # state_dict keys match existing checkpoints.
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # passed through SiLU
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # projects back to hidden_size
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # multiplied by the SiLU branch

    def forward(self, x):
        activated = F.silu(self.w1(x))
        linear_branch = self.w3(x)
        return self.w2(activated * linear_branch)
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Args:
        hidden_size: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        # Without this check, the floor division below silently yields a
        # head_dim with num_heads * head_dim != hidden_size, and forward() then
        # fails with an opaque view() shape error.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Hidden states of shape ``(B, T, C)``.
            mask: Optional tensor broadcastable to ``(B, heads, T, T)``.
                Score positions where it equals 0 are set to -inf before the
                softmax. A row that is masked everywhere would produce NaN.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # (B, T, C) -> (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Rotary embeddings are applied to queries and keys only, not values.
        cos, sin = self.rotary(x, T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Scaled dot-product attention
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, heads, T, T)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        # (B, heads, T, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: attention sublayer, then SwiGLU feed-forward.

    Each sublayer normalizes its input with RMSNorm, and its output is added
    back onto the residual stream.

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        intermediate_size: Inner width of the feed-forward sublayer.
        dropout: Attention-weight dropout probability.
    """

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, dropout: float = 0.0):
        super().__init__()
        # Submodule creation order is kept so initialization draws and
        # state_dict keys are unchanged.
        self.norm1 = RMSNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads, dropout)
        self.norm2 = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, intermediate_size)

    def forward(self, x, mask=None):
        residual = x + self.attn(self.norm1(x), mask)
        ffn_out = self.ffn(self.norm2(residual))
        return residual + ffn_out
class TinyLLM(nn.Module):
    """Decoder-only causal language model built from pre-norm TransformerBlocks.

    Args:
        vocab_size: Number of token ids.
        hidden_size: Model width.
        num_layers: Number of TransformerBlock layers.
        num_heads: Attention heads per layer.
        intermediate_size: Inner width of each feed-forward sublayer.
        max_position_embeddings: Side length of the causal mask, and therefore
            the longest sequence ``forward`` accepts.
        dropout: Attention-weight dropout probability.
        tie_weights: If True, ``lm_head`` shares its weight tensor with
            ``embed_tokens``.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 512,
        num_layers: int = 12,
        num_heads: int = 8,
        intermediate_size: int = 1408,
        max_position_embeddings: int = 512,
        dropout: float = 0.0,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.embed_tokens.weight
        # Causal mask: lower-triangular 0/1 matrix, so query i only sees keys
        # j <= i. It is a persistent buffer (stored in the state_dict) and is
        # kept that way so existing checkpoints load unchanged.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_position_embeddings, max_position_embeddings))
        )
        self._init_weights()

    def _init_weights(self):
        """Draw every Linear and Embedding weight from N(0, 0.02**2).

        RMSNorm weights are not touched and stay at 1. With tied weights, the
        shared matrix is visited twice and simply re-drawn from the same
        distribution.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the language-model loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            labels: Optional targets expected to have the same shape as
                ``input_ids``. They are shifted internally, so position t is
                scored against label t + 1. Entries equal to -100 are ignored.

        Returns:
            Dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)`` and
            ``"loss"``: the mean cross-entropy, or None when ``labels`` is None.

        Raises:
            ValueError: If ``seq_len`` exceeds the causal-mask size
                (``max_position_embeddings``).
        """
        _, seq_len = input_ids.shape
        max_len = self.causal_mask.size(0)
        # Slicing past the buffer silently returns a smaller mask, which then
        # fails deep inside attention with a broadcasting error. Fail clearly.
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings ({max_len})"
            )
        x = self.embed_tokens(input_ids)
        mask = self.causal_mask[:seq_len, :seq_len]
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Drop the last prediction and the first label so logits[t] is
            # compared with labels[t + 1].
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        """Return the total number of parameter elements.

        ``Module.parameters()`` yields a shared tensor only once, so a tied
        embedding/lm_head weight is counted a single time.
        """
        return sum(p.numel() for p in self.parameters())
if __name__ == "__main__":
    # Smoke test: build the default model and run one forward pass with loss.
    model = TinyLLM()
    print(f"Parameters: {model.count_parameters() / 1e6:.2f}M")
    # Sample ids from the model's own vocabulary rather than a hard-coded size,
    # so this stays valid if the default vocab_size changes.
    x = torch.randint(0, model.vocab_size, (2, 128))
    out = model(x, labels=x)
    print(f"Loss: {out['loss'].item():.4f}")
    print(f"Logits shape: {out['logits'].shape}")