scorcher / train.py
seankross's picture
Upload 6 files
63b0b0b
raw
history blame contribute delete
No virus
3.09 kB
# Reference hyperparameters. transformer_model is not constructed in this
# chunk; the commented-out lines below presumably mirror how it was built
# elsewhere -- NOTE(review): confirm against the actual construction.
# d_model = 512  # Dimension of the embeddings and the token representations
# seq_length = 10  # Length of the input and output sequences
# vocab_size = 1000  # Size of the vocabulary
# batch_size = 32  # Batch size for training
# num_heads = 8  # Number of heads in multi-head attention
# dim_feedforward = 2048  # Dimension of feedforward network in encoder and decoder
# max_seq_length = seq_length  # Maximum sequence length
# transformer_model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, max_seq_length)

# Move the model's parameters onto the training device.
transformer_model = transformer_model.to(device)

# Loss and optimizer.
# Token id 0 is treated as padding by this script's padding-mask helper, so
# padded target positions are excluded from the loss rather than training the
# model to predict padding.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)
# Causal (look-ahead) mask helper.
def generate_square_subsequent_mask(sz):
    """Build an ``(sz, sz)`` boolean causal mask on the module-level ``device``.

    Entry ``[i, j]`` is True exactly when ``j > i`` (strictly above the
    diagonal). For a boolean attention mask, PyTorch treats True as
    "may not attend", so position ``i`` cannot see later positions.
    """
    positions = torch.arange(sz, device=device)
    # Row index i down the first axis, column index j along the second.
    return positions.unsqueeze(0) > positions.unsqueeze(1)
# Padding mask helper.
def create_padding_mask(seq):
    """Return a boolean mask that is True wherever ``seq`` holds token id 0.

    Token id 0 is assumed to be the padding index. Dims 0 and 1 of the mask
    are swapped relative to ``seq``, so a seq-first ``(S, N)`` input yields
    an ``(N, S)`` mask -- the key-padding layout ``nn.Transformer`` expects.
    """
    is_pad = seq.eq(0)  # assumes 0 is the padding index
    return is_pad.transpose(1, 0)
# Training loop.
# NOTE(review): assumes train_data_loader yields batch-first token-id tensors
# of shape (N, L) -- the teacher-forcing shift along dim 1 implies dim 1 is
# time -- and that TransformerModel follows nn.Transformer's default
# seq-first layout: token inputs (S, N), causal masks (S, S), key-padding
# masks (N, S), output (T, N, vocab). Confirm against the model definition.
num_epochs = 10
for epoch in range(num_epochs):
    transformer_model.train()
    epoch_loss = 0
    for batch_idx, batch in enumerate(train_data_loader):
        batch = batch.to(device)
        # Convert to seq-first (L, N) once so every tensor below shares one
        # layout. Previously masks were sized by the batch dimension and the
        # targets were flattened in a different order than the model output.
        seq_first = batch.transpose(0, 1)
        src_batch = seq_first[:-1]  # (S, N): all tokens but the last
        tgt_input = seq_first[:-1]  # (T, N): decoder input, excludes the <eos> token
        # Next-token targets (shifted by one for teacher forcing), flattened
        # time-major to line up with the flattened (T, N, vocab) output.
        targets = seq_first[1:].reshape(-1)
        optimizer.zero_grad()
        # Causal masks are (sequence x sequence); dim 0 is time in seq-first layout.
        src_mask = generate_square_subsequent_mask(src_batch.size(0))
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(0))
        # create_padding_mask swaps dims 0/1, turning (S, N) into the (N, S)
        # key-padding layout.
        src_padding_mask = create_padding_mask(src_batch)
        tgt_padding_mask = create_padding_mask(tgt_input)
        # Forward pass
        output = transformer_model(
            src_batch, tgt_input,
            src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask,
        )
        # Flatten to (T*N, vocab). reshape tolerates non-contiguous output, and
        # size(-1) avoids depending on a module-level vocab_size.
        output = output.reshape(-1, output.size(-1))
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if batch_idx % 100 == 0:  # Print loss every 100 batches
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {epoch_loss / (batch_idx+1):.4f}")
    # Guard against an empty loader instead of raising ZeroDivisionError.
    num_batches = len(train_data_loader)
    average_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs} finished, Average Loss: {average_loss:.4f}")
# Persist the trained weights and the vocabulary to the working directory.
import json

model_weights = transformer_model.state_dict()
torch.save(model_weights, 'transformer_model.pth')

# dataset.vocab is written as JSON; NOTE(review): non-string dict keys, if
# any, will be converted to strings by json -- confirm the vocab's key type.
with open('vocabulary.json', 'w') as vocab_out:
    json.dump(dataset.vocab, vocab_out)