# mtp1 / model.py
# Uploaded by teszenofficial: commit e398fee ("Create model.py", verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``d_model`` is split evenly across ``n_heads`` heads of width
    ``d_k = d_model // n_heads``. ``forward`` unpacks ``x.size()`` into
    three values, so ``x`` is expected to be (batch, seq_len, d_model).
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Registration order (q, k, v, out, dropout) is kept stable so that
        # parameter ordering and seeded initialisation do not change.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, bsz, length):
        """Reshape (bsz, length, d_model) into (bsz, n_heads, length, d_k)."""
        return projected.view(bsz, length, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: input tensor of shape (batch, seq_len, d_model).
            mask: optional tensor broadcastable to the
                (batch, n_heads, seq_len, seq_len) score matrix. Positions
                where ``mask == 0`` are set to -inf before the softmax.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        bsz, length, width = x.size()

        queries = self._split_heads(self.q_linear(x), bsz, length)
        keys = self._split_heads(self.k_linear(x), bsz, length)
        values = self._split_heads(self.v_linear(x), bsz, length)

        scale = math.sqrt(self.d_k)
        scores = (queries @ keys.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = (weights @ values).transpose(1, 2).contiguous()
        return self.out_linear(merged.view(bsz, length, width))
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear.

    Expands the last dimension from ``d_model`` to ``d_ff`` and projects
    it back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expansion
        self.linear2 = nn.Linear(d_ff, d_model)  # projection back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP independently to every position of ``x``."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class TransformerBlock(nn.Module):
    """Single pre-LayerNorm Transformer decoder block.

    Each sublayer (self-attention, then feed-forward) receives a
    LayerNorm-ed copy of the input. Its output passes through dropout and
    is added back onto the residual stream.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both residual sublayers; ``mask`` is forwarded to attention."""
        # Attention sublayer: normalise, attend, dropout, add to residual.
        x = x + self.dropout1(self.attention(self.ln1(x), mask))
        # Feed-forward sublayer: normalise, MLP, dropout, add to residual.
        return x + self.dropout2(self.feed_forward(self.ln2(x)))
class MTPMiniModel(nn.Module):
    """MTP Mini - GPT-style Transformer Language Model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``TransformerBlock``s, followed by a final LayerNorm and an output head.
    The output head's weight is tied to the token embedding matrix.
    Inputs may be at most ``max_seq_len`` tokens long.
    """

    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4,
                 d_ff=1024, max_seq_len=128, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Positional embeddings (learnable); one row per position < max_seq_len
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)

        # Output projection to vocabulary
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the head shares its weight tensor with the embedding.
        self.lm_head.weight = self.token_embedding.weight

        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02); LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            input_ids: integer token ids of shape (batch, seq_len), where
                seq_len <= max_seq_len.
            targets: optional token ids with the same number of elements as
                ``input_ids``. They may be non-contiguous (e.g. ``batch[:, 1:]``).

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq_len, vocab_size).
            ``loss`` is None when ``targets`` is None.

        Raises:
            ValueError: if seq_len exceeds ``max_seq_len``. Without this check
                the positional embedding lookup fails with an opaque index
                error (a device-side assert on CUDA).
        """
        _, seq_len = input_ids.size()
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        device = input_ids.device

        # Causal mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(1, 1, seq_len, seq_len)

        # Token embeddings + positional embeddings
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        tok_emb = self.token_embedding(input_ids)
        pos_emb = self.position_embedding(positions)
        x = self.dropout(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))

        return logits, loss

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Set logits outside the top-k set and/or the top-p nucleus to -inf.

        ``top_k`` of None or <= 0 disables top-k; ``top_p`` of None or >= 1.0
        disables nucleus filtering. ``logits`` is (batch, vocab) and is not
        modified in place.
        """
        if top_k is not None and top_k > 0:
            kth_value = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))

        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the single most likely token.
            sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
            sorted_remove[:, 0] = False
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))

        return logits

    def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=50, top_p=0.9):
        """Autoregressively extend ``input_ids`` by ``max_new_tokens`` tokens.

        Args:
            input_ids: (batch, seq_len) integer token ids. Only the last
                ``max_seq_len`` tokens are fed to the model at each step.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling. ``0``
                selects greedy (argmax) decoding instead of sampling.
            top_k: keep only the k most likely tokens (None/<=0 disables).
            top_p: nucleus sampling threshold (None/>=1.0 disables).

        Returns:
            Tensor of shape (batch, seq_len + max_new_tokens).

        The model runs in eval mode under ``torch.no_grad()``. Its previous
        train/eval mode is restored afterwards, so calling this mid-training
        does not silently leave dropout disabled.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Crop the context to the positional-embedding window.
                    input_ids_cond = input_ids[:, -self.max_seq_len:]
                    logits, _ = self(input_ids_cond)
                    logits = logits[:, -1, :]

                    if temperature == 0:
                        # Greedy decoding; dividing by 0 would yield NaN probabilities.
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        logits = self._filter_logits(logits / temperature, top_k, top_p)
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)

                    input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            self.train(was_training)
        return input_ids

    def count_parameters(self):
        """Count trainable parameters (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)