# scorcher/transformer_model.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# Hyperparameters
d_model = 512 # Dimension of the embeddings and the token representations
# batch_size = 32 # Batch size for training
num_heads = 8 # Number of heads in multi-head attention
dim_feedforward = 2048 # Dimension of the feedforward network in the encoder and decoder
vocab_size = 25672 # Size of the token vocabulary
seq_length = 10 # Length of the input and output sequences
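# Note: d_model must be divisible by num_heads (here 512 / 8 = 64 dimensions per attention head).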
# To generate text, we need to define several components (a shape sketch follows this list):
# 1. Token Embedding: To convert token indices to vectors
# 2. Positional Encoding: To give the model information about the order of tokens
# 3. Transformer Block: Consisting of an encoder and a decoder
# 4. Output Layer: To convert the transformer output to token probabilities
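# A rough sketch of how data flows through these components (shapes assume the
# sequence-first layout used below, i.e. [seq_length, batch_size, ...]):
#   token ids [seq_length, batch_size]
#     -> TokenEmbedding     -> [seq_length, batch_size, d_model]
#     -> PositionalEncoding -> [seq_length, batch_size, d_model]
#     -> TransformerBlock   -> [seq_length, batch_size, d_model]
#     -> OutputLayer        -> [seq_length, batch_size, vocab_size] (log-probabilities)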
# Token Embedding
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, tokens):
        return self.embedding(tokens)
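# Note: nn.Embedding above stores a learned vocab_size x d_model weight matrix and maps
# each token index to its corresponding d_model-dimensional row.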
# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.register_buffer('pe', self.generate_positional_encoding())

    def generate_positional_encoding(self):
        # Create a 'pe' buffer long enough for 'max_len' positions
        pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: [max_len, 1, d_model]
        return pe
    def forward(self, x):
        # x has shape [seq_length, batch_size, d_model]
        # Select positional encodings up to the sequence length of x; the slice has
        # shape [seq_length, 1, d_model] and broadcasts across the batch dimension
        pe = self.pe[:x.size(0), :]
        x = x + pe
        return x
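# For reference, the buffer computed above follows the sinusoidal scheme from
# "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# where pos is the token position and i indexes pairs of embedding dimensions.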
# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, dim_feedforward, dropout=0.1):
        super().__init__()
        # Encoder: a stack of 10 encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=10)
        # Decoder: a stack of 10 decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=10)
    def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
        # Expected shapes (sequence-first, the default for nn.TransformerEncoder/Decoder):
        #   src, tgt:                                   [seq_length, batch_size, d_model]
        #   src_mask, tgt_mask:                         [seq_length, seq_length]
        #   src_key_padding_mask, tgt_key_padding_mask: [batch_size, seq_length]
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=None,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=src_key_padding_mask)
        return output
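# Note on masks: a True entry in a *_key_padding_mask tells the attention layers to
# ignore that (padded) position, while tgt_mask is typically the causal mask produced
# by nn.Transformer.generate_square_subsequent_mask, which blocks attention to future positions.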
# Output Layer
class OutputLayer(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # Project to vocabulary size and return log-probabilities
        return F.log_softmax(self.linear(x), dim=-1)
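# The log-probabilities returned here pair naturally with nn.NLLLoss during training;
# alternatively, the layer could return raw logits and use nn.CrossEntropyLoss instead.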
# Putting it all together
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, dim_feedforward, max_seq_length):
        super().__init__()
        self.d_model = d_model
        self.embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        self.transformer_block = TransformerBlock(d_model, num_heads, dim_feedforward)
        self.output_layer = OutputLayer(d_model, vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
        # src, tgt: token indices of shape [seq_length, batch_size]
        # Scale the embeddings by sqrt(d_model) before adding positional encodings
        src = self.embedding(src) * math.sqrt(self.d_model)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        # Keep the sequence-first layout ([seq_length, batch_size, d_model]) expected by
        # the encoder/decoder with the default batch_first=False
        transformer_output = self.transformer_block(src, tgt, src_mask, tgt_mask,
                                                    src_key_padding_mask, tgt_key_padding_mask)
        out = self.output_layer(transformer_output)
        return out
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
transformer_model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, seq_length)

# Example forward pass (uncomment batch_size above before running):
# # Create a sample source and target batch
# src = torch.randint(low=0, high=vocab_size, size=(seq_length, batch_size))
# tgt = torch.randint(low=0, high=vocab_size, size=(seq_length, batch_size))
# # Masks and padding masks
# src_mask = torch.zeros((seq_length, seq_length)).type(torch.bool)
# tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
# src_key_padding_mask = torch.zeros(batch_size, seq_length).type(torch.bool)  # Assuming no padding
# tgt_key_padding_mask = torch.zeros(batch_size, seq_length).type(torch.bool)  # Assuming no padding
# # Forward pass
# output = transformer_model(src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask)
# # Show the output size
# print(output.size())  # (seq_length, batch_size, vocab_size)
# In this code, we define a simple Transformer model for educational purposes.
# It can be expanded by increasing the number of layers, adding dropout to the embedding
# and positional-encoding stages, and including more complex masking.
# The actual training loop, loss calculation, and optimization steps are not shown.
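# A minimal greedy-decoding sketch showing how the pieces above could be used to
# generate text. This is illustrative only and not part of the original code: the
# bos_id/eos_id token ids are hypothetical placeholders and must be replaced with the
# real special-token ids of the vocabulary.
@torch.no_grad()
def greedy_decode(model, src, max_len=seq_length, bos_id=1, eos_id=2):
    # src: token indices of a single sequence, shape [src_seq_length, 1]
    model.eval()  # disable dropout for deterministic decoding
    device = src.device
    # All-False bool mask: the encoder may attend to every source position
    src_mask = torch.zeros(src.size(0), src.size(0), dtype=torch.bool, device=device)
    ys = torch.full((1, 1), bos_id, dtype=torch.long, device=device)  # [tgt_len, 1]
    for _ in range(max_len - 1):
        # Causal mask so each target position only attends to earlier positions
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        out = model(src, ys, src_mask, tgt_mask, None, None)  # [tgt_len, 1, vocab_size]
        next_token = out[-1, 0].argmax().item()
        ys = torch.cat([ys, torch.tensor([[next_token]], dtype=torch.long, device=device)], dim=0)
        if next_token == eos_id:
            break
    return ys.squeeze(1)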