File size: 3,092 Bytes
63b0b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Hyperparameters.
# The model is constructed outside this section; the commented-out lines below
# show the configuration the commented-out constructor call would use.
# NOTE(review): `vocab_size` (needed when reshaping logits in the training
# loop) and `device` are not defined in this section — they must already be
# in scope when this runs.
# d_model = 512  # Dimension of the embeddings and the token representations
# seq_length = 10  # Length of the input and output sequences
# vocab_size = 1000  # Size of the vocabulary
# batch_size = 32  # Batch size for training
# num_heads = 8  # Number of heads in multi-head attention
# dim_feedforward = 2048  # Dimension of feedforward network in encoder and decoder
# max_seq_length = seq_length  # Maximum sequence length

# Initialize the model:
# transformer_model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, max_seq_length)
transformer_model = transformer_model.to(device)

# Define loss function and optimizer.
# ignore_index=0 keeps padded target positions out of the loss; 0 is the same
# index that create_padding_mask treats as padding.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)

def generate_square_subsequent_mask(sz, mask_device=None):
    """Return a causal (look-ahead) mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``True`` exactly when ``j > i`` (strict upper
    triangle). Under PyTorch's boolean attention-mask convention
    (``nn.Transformer`` / ``nn.MultiheadAttention``), ``True`` marks a
    position that may not be attended to. Presumably TransformerModel
    follows that convention — confirm.

    Args:
        sz: Sequence length (number of rows and columns).
        mask_device: Device to allocate the mask on. When omitted, the
            module-level ``device`` is used, matching the original behavior.

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)``.
    """
    if mask_device is None:
        # Fall back to the script-wide device, as the original did.
        mask_device = device
    return torch.triu(torch.ones(sz, sz, device=mask_device), diagonal=1).bool()

def create_padding_mask(seq, pad_idx=0):
    """Mark padding positions in a 2-D token tensor and swap its two axes.

    Positions equal to ``pad_idx`` become ``True``. Dims 0 and 1 are then
    transposed, so a sequence-first ``(seq_len, batch)`` input yields the
    ``(batch, seq_len)`` layout PyTorch uses for ``key_padding_mask``.

    Args:
        seq: Token-id tensor; expected to be ``(seq_len, batch)``.
        pad_idx: Token id used for padding (default 0, as assumed by the
            rest of this script).

    Returns:
        A boolean tensor with dims 0 and 1 of ``seq`` swapped.
    """
    return (seq == pad_idx).transpose(0, 1)

# Training Loop
num_epochs = 10
for epoch in range(num_epochs):
    transformer_model.train()
    epoch_loss = 0
    batches_seen = 0

    for batch_idx, batch in enumerate(train_data_loader):
        # `batch` is sliced along dim 1, so dim 1 is the token axis and dim 0 the
        # batch axis: (batch, seq_len + 1) -> inputs/targets of (batch, seq_len).
        src_batch = batch[:, :-1].to(device)
        tgt_input = batch[:, :-1].to(device)  # Every position except the last
        tgt_output = batch[:, 1:].to(device)  # Inputs shifted left by one (teacher forcing)

        # NOTE(review): src and tgt carry the same tokens. Unless TransformerModel
        # masks the encoder memory in cross-attention, decoder position i can
        # attend to memory position i+1, which encodes the very token it must
        # predict — confirm this is handled.

        # The model is called with sequence-first tensors (seq_len, batch), so
        # every mask is derived from those: size(0) is the sequence length, and
        # create_padding_mask swaps the axes into (batch, seq_len).
        src = src_batch.transpose(0, 1)
        tgt = tgt_input.transpose(0, 1)

        optimizer.zero_grad()

        src_mask = generate_square_subsequent_mask(src.size(0))  # (seq_len, seq_len)
        tgt_mask = generate_square_subsequent_mask(tgt.size(0))  # (seq_len, seq_len)
        src_padding_mask = create_padding_mask(src)  # (batch, seq_len)
        tgt_padding_mask = create_padding_mask(tgt)  # (batch, seq_len)

        # Forward pass
        output = transformer_model(
            src, tgt,
            src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask,
        )

        # NOTE(review): assumes logits come back sequence-first,
        # (seq_len, batch, vocab_size), as with nn.Transformer(batch_first=False).
        # Targets are flattened in that same (seq_len, batch) order so that row k
        # of `output` is scored against label k — confirm against TransformerModel.
        targets = tgt_output.transpose(0, 1).reshape(-1)
        output = output.reshape(-1, vocab_size)
        loss = loss_fn(output, targets)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        batches_seen += 1

        if batch_idx % 100 == 0:  # Print loss every 100 batches
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {epoch_loss / (batch_idx+1):.4f}")

    # Average over the batches actually processed: avoids ZeroDivisionError on
    # an empty loader and does not require the loader to support len().
    print(f"Epoch {epoch+1}/{num_epochs} finished, Average Loss: {epoch_loss / max(batches_seen, 1):.4f}")

# Persist training artifacts: the model's parameters (its state_dict, not the
# module object itself) and the dataset vocabulary serialized as JSON.
import json

torch.save(transformer_model.state_dict(), 'transformer_model.pth')

with open('vocabulary.json', 'w') as vocab_file:
    # Whatever `dataset.vocab` holds must be JSON-serializable.
    json.dump(dataset.vocab, vocab_file)