import torch
import torch.nn as nn
from model.layers import TransformerBlock


class GPTModel(nn.Module):
    """
    GPT-style Language Model (decoder-only Transformer).
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int, n_layers: int,
                 n_heads: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Embedding layers
        self.tok_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_position_embeddings, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        # Final layer normalization
        self.ln_f = nn.LayerNorm(hidden_dim)
        # Output projection to vocabulary size
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        """
        x: Tensor of token IDs with shape (batch_size, seq_length).
        Returns: Logits of shape (batch_size, seq_length, vocab_size).
        """
        batch_size, seq_length = x.shape
        # Token and positional embeddings
        tok_emb = self.tok_embedding(x)  # (batch, seq_len, hidden_dim)
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)  # (1, seq_len, hidden_dim)
        h = self.dropout(tok_emb + pos_emb)  # (batch, seq_len, hidden_dim)
        # Transformer decoder blocks
        for layer in self.layers:
            h = layer(h)
        # Final layer norm
        h = self.ln_f(h)
        # Compute logits
        logits = self.output_proj(h)  # (batch, seq_len, vocab_size)
        return logits
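

# Usage sketch (the hyperparameter values below are illustrative assumptions,
# not part of the original module). It instantiates GPTModel and runs a
# forward pass on random token IDs to check the output shape.
if __name__ == "__main__":
    model = GPTModel(
        vocab_size=50257,
        max_position_embeddings=1024,
        n_layers=12,
        n_heads=12,
        hidden_dim=768,
    )
    # Random batch of token IDs: (batch_size=2, seq_length=16)
    dummy_input = torch.randint(0, 50257, (2, 16))
    logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])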