import torch
import torch.nn as nn


class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_layers=24, max_len=512):
        super().__init__()

        # Learned token and absolute position embeddings.
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Stack of standard Transformer encoder layers; the causal mask applied
        # in forward() turns this encoder into a decoder-only language model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dropout=0.0, batch_first=False
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Final layer norm and projection back to vocabulary logits.
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, T, device):
        # Boolean mask where True marks positions that may NOT be attended to,
        # so each token only attends to itself and earlier positions.
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()

    def forward(self, input_ids):
        B, T = input_ids.shape

        # Sum token and position embeddings: (B, T, d_model).
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(pos)

        # The encoder was built with batch_first=False, so switch to (T, B, d_model).
        x = x.transpose(0, 1)

        # Apply the causal mask so the model is autoregressive.
        mask = self.generate_causal_mask(T, input_ids.device)
        x = self.transformer(x, mask=mask)

        # Back to (B, T, d_model), then normalize and project to vocabulary logits.
        x = x.transpose(0, 1)
        x = self.ln(x)
        return self.fc_out(x)

    def reset_params(self):
        # Re-initialize every submodule that defines reset_parameters().
        # self.modules() recurses into the transformer layers, which
        # iterating self.children() alone would skip.
        for layer in self.modules():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
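

# Minimal usage sketch (not part of the original file): the vocabulary size,
# batch shape, and reduced model dimensions below are illustrative assumptions.
if __name__ == "__main__":
    model = MiniGPT(vocab_size=1000, d_model=128, n_heads=4, n_layers=2, max_len=64)
    dummy_ids = torch.randint(0, 1000, (2, 16))  # (batch, sequence) of token ids
    logits = model(dummy_ids)                    # -> (2, 16, 1000) vocabulary logits
    print(logits.shape)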