import torch
import torch.nn as nn

from model.layers import TransformerBlock


class GPTModel(nn.Module):
    """
    GPT-style language model (decoder-only Transformer).
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int, n_layers: int,
                 n_heads: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Embedding layers
        self.tok_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_position_embeddings, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        # Final layer normalization
        self.ln_f = nn.LayerNorm(hidden_dim)
        # Output projection to vocabulary size
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        """
        x: Tensor of token IDs with shape (batch_size, seq_length).
        Returns: Logits of shape (batch_size, seq_length, vocab_size).
        """
        batch_size, seq_length = x.shape

        # Token and positional embeddings
        tok_emb = self.tok_embedding(x)  # (batch, seq_len, hidden_dim)
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)  # (1, seq_len, hidden_dim)
        h = self.dropout(tok_emb + pos_emb)  # (batch, seq_len, hidden_dim)

        # Transformer decoder blocks
        for layer in self.layers:
            h = layer(h)

        # Final layer norm
        h = self.ln_f(h)

        # Compute logits
        logits = self.output_proj(h)  # (batch, seq_len, vocab_size)
        return logits
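

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyperparameter values below
# are assumptions for a quick smoke test, not values taken from a config in
# this repository, and it presumes model.layers.TransformerBlock
# (hidden_dim, n_heads, dropout) maps (batch, seq_len, hidden_dim) to the
# same shape, as the forward pass above expects.
if __name__ == "__main__":
    model = GPTModel(
        vocab_size=50257,              # assumed GPT-2-style vocabulary size
        max_position_embeddings=1024,  # assumed maximum context length
        n_layers=2,                    # small values so the smoke test runs fast
        n_heads=4,
        hidden_dim=128,
        dropout=0.1,
    )
    model.eval()

    # Dummy batch of token IDs: shape (batch_size=2, seq_length=16)
    tokens = torch.randint(0, 50257, (2, 16))
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])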