MiniGPT / model.py
import torch
import torch.nn as nn


class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_layers=24, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        # 🎯 CHANGE 1: Set dropout to 0.0 for debugging underfitting on tiny data.
        # This allows the model to memorize the small dataset.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dropout=0.0, batch_first=False)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
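
    # Rough scale (assuming the default dim_feedforward=2048): each encoder
    # layer carries ~4*d_model^2 attention weights plus ~2*d_model*2048 FFN
    # weights, about 8.4M params per layer, so the 24-layer default is
    # roughly 200M parameters before embeddings.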

    def generate_causal_mask(self, T, device):
        # This mask is correct for a TransformerEncoder used causally (True masks future tokens)
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
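
    # Example (illustrative): for T = 3 the mask above is
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # so position t may attend only to positions <= t; nn.MultiheadAttention
    # treats True entries as "do not attend".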

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(pos)
        x = x.transpose(0, 1)  # [T, B, D] (batch_first=False)
        # Causal mask
        mask = self.generate_causal_mask(T, input_ids.device)
        x = self.transformer(x, mask=mask)
        x = x.transpose(0, 1)  # [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)

    def reset_params(self):
        # Recurse with self.modules() rather than self.children():
        # children() only yields direct submodules, so the layers inside
        # the TransformerEncoder would never be reset.
        for module in self.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
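

# --- Illustrative usage (not part of the uploaded file) ---------------------
# A minimal greedy decoding sketch; `bos_id` and the 20-token budget are
# hypothetical assumptions, not values shipped with this checkpoint.
@torch.no_grad()
def greedy_generate(model, bos_id=0, max_new_tokens=20, device="cpu"):
    model.eval()
    ids = torch.tensor([[bos_id]], device=device)  # [1, 1]
    for _ in range(max_new_tokens):
        logits = model(ids)                  # [1, T, vocab_size]
        next_id = logits[:, -1].argmax(-1)   # greedy pick of the next token
        ids = torch.cat([ids, next_id.unsqueeze(1)], dim=1)
    return ids


if __name__ == "__main__":
    # Smoke test with a deliberately tiny config (assumed values, chosen so
    # the test runs quickly; the defaults above are much larger).
    model = MiniGPT(vocab_size=1000, d_model=64, n_heads=4, n_layers=2)
    logits = model(torch.randint(0, 1000, (2, 16)))  # input [B=2, T=16]
    assert logits.shape == (2, 16, 1000)             # output [B, T, vocab_size]
    print("OK:", greedy_generate(model).shape)       # -> torch.Size([1, 21])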