File size: 1,619 Bytes
4de3b20
 
 
 
79eec1d
4de3b20
79eec1d
4de3b20
 
79eec1d
 
 
 
4de3b20
79eec1d
4de3b20
 
 
79eec1d
 
 
 
4de3b20
 
 
 
 
79eec1d
 
 
 
 
4de3b20
 
 
79eec1d
4de3b20
 
79eec1d
4de3b20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn

class MiniGPT(nn.Module):
    """GPT-style (causal) language model built on ``nn.TransformerEncoder``.

    Token embeddings and learned positional embeddings are summed, passed
    through a stack of encoder layers under a causal attention mask,
    layer-normalised, and projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids; also the width of the output logits.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per layer (must divide ``d_model``).
        n_layers: Number of stacked encoder layers.
        max_len: Longest sequence the positional embedding table supports.
        dropout: Dropout probability inside each encoder layer. Defaults to
            0.0 so the model can memorise a tiny dataset when debugging
            underfitting.
    """

    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_layers=24, max_len=512,
                 dropout=0.0):
        super().__init__()

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Sequence-first layout ([T, B, D]); forward() transposes to match.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dropout=dropout, batch_first=False
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, T, device):
        """Return a ``[T, T]`` boolean mask that blocks attention to the future.

        Entry ``[i, j]`` is True exactly when ``j > i``. For a boolean ``mask``
        argument, ``nn.TransformerEncoder`` treats True as "may not attend",
        so position ``i`` only sees positions ``0..i``.
        """
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape ``[B, T]`` with
                ``T <= max_len``.

        Returns:
            Logits of shape ``[B, T, vocab_size]``.

        Raises:
            ValueError: If ``T`` exceeds the positional embedding table size.
        """
        _, T = input_ids.shape
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            # Fail with a readable message instead of an opaque embedding
            # index error (or a device-side assert on CUDA).
            raise ValueError(f"sequence length {T} exceeds max_len={max_len}")

        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)  # [1, T], broadcast over B
        x = self.token_embed(input_ids) + self.pos_embed(pos)
        x = x.transpose(0, 1)  # [T, B, D]

        mask = self.generate_causal_mask(T, input_ids.device)

        x = self.transformer(x, mask)
        x = x.transpose(0, 1)  # [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)

    def reset_params(self):
        """Re-initialise every parameter in the model, including the encoder stack.

        The whole module tree is visited, not just direct children. The direct
        child ``nn.TransformerEncoder`` has no ``reset_parameters`` of its own,
        so a children-only loop would leave every attention and feed-forward
        weight untouched.

        NOTE: ``nn.TransformerEncoder`` copies ``encoder_layer`` for each
        layer, so freshly constructed layers start with identical weights.
        After this call each layer is initialised independently.
        """
        def _reset(module):
            # Children first, so that a parent's own initialisation wins
            # (e.g. MultiheadAttention zeroes out_proj.bias after out_proj
            # was initialised). This mirrors the order used at construction.
            for child in module.children():
                _reset(child)
            reset = getattr(module, "reset_parameters", None)
            if reset is None:
                # nn.MultiheadAttention only exposes the private hook, which
                # is what (re)initialises in_proj_weight / in_proj_bias.
                reset = getattr(module, "_reset_parameters", None)
            if callable(reset):
                reset()

        _reset(self)