seankross committed
Commit 63b0b0b
1 parent: f533fae

Upload 6 files

Files changed (6)
  1. load_data.py +71 -0
  2. main.py +9 -0
  3. predict.py +118 -0
  4. tiny-shakespeare.txt +0 -0
  5. train.py +83 -0
  6. transformer_model.py +188 -0
load_data.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, Dataset
+ from collections import Counter
+
+
+ # # Hyperparameters
+ # d_model = 512  # Dimension of the embeddings and the token representations
+ seq_length = 10  # Length of the input and output sequences
+ # vocab_size = 1000  # Size of the vocabulary
+ batch_size = 32  # Batch size for training
+ # num_heads = 8  # Number of heads in multi-head attention
+ # dim_feedforward = 2048  # Dimension of the feedforward network in the encoder and decoder
+
+
+ # Assumes `transformer_model` and the remaining hyperparameters are defined elsewhere (see transformer_model.py)
+
+ class TextDataset(Dataset):
+     def __init__(self, text, vocab=None, seq_length=seq_length):
+         # Tokenization - simple split on whitespace
+         self.tokens = text.split()
+
+         # If a vocabulary is provided, use it; otherwise build a new one
+         if vocab:
+             self.vocab = vocab
+         else:
+             # Build the vocabulary from the unique tokens, with <pad> and <eos> added first
+             self.vocab = {'<pad>': 0, '<eos>': 1}
+             token_counts = Counter(self.tokens)
+             for token in token_counts:
+                 self.vocab[token] = len(self.vocab)
+
+         # Inverse mapping from indices to tokens
+         self.index2token = {index: token for token, index in self.vocab.items()}
+
+         # Convert tokens to indices
+         self.indexed_tokens = [self.vocab[token] for token in self.tokens]
+
+         # Sequence length
+         self.seq_length = seq_length
+
+     def __len__(self):
+         # Number of full sequences; the -1 guarantees every item can include its next-token target
+         return (len(self.indexed_tokens) - 1) // self.seq_length
+
+     def __getitem__(self, idx):
+         # Slice the indexed_tokens to get a sequence
+         start_idx = idx * self.seq_length
+         end_idx = start_idx + self.seq_length + 1  # +1 extra token used as the shifted target
+         sequence = self.indexed_tokens[start_idx:end_idx]
+         # Convert to a torch tensor
+         return torch.tensor(sequence, dtype=torch.long)
+
+ # Load the text from a file
+ with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as file:
+     text = file.read()
+
+ # Create the dataset
+ dataset = TextDataset(text)
+ vocab_size = len(dataset.vocab)  # Update vocab_size for the transformer model
+ train_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ # Function to generate a padding mask for sequences
+ def create_padding_mask(seq):
+     return (seq == dataset.vocab['<pad>']).transpose(0, 1)
+
+ # The training loop remains the same as before
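A minimal sanity-check sketch, assuming the dataset and DataLoader above have already been built: it pulls one batch and decodes the first sequence back into text via dataset.index2token.

batch = next(iter(train_data_loader))
print(batch.shape)  # (batch_size, seq_length + 1); the extra token is the shifted target
print(" ".join(dataset.index2token[i] for i in batch[0].tolist()))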
main.py ADDED
@@ -0,0 +1,9 @@
+ # List of file paths
+ file_paths = ["load_data.py", "transformer_model.py", "train.py"]
+
+ # Loop through the file paths
+ for file_path in file_paths:
+     # Use exec() to run the code from each file
+     with open(file_path, 'r') as file:
+         code = file.read()
+         exec(code)
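Because exec() runs each file in this module's globals, the three scripts silently share one namespace: train.py sees train_data_loader from load_data.py and transformer_model from transformer_model.py. A minimal alternative sketch that makes the shared namespace explicit (and gives readable tracebacks via compile):

shared_namespace = {"__name__": "__main__"}
for file_path in file_paths:
    with open(file_path, "r") as file:
        exec(compile(file.read(), file_path, "exec"), shared_namespace)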
predict.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ import json
+ import itertools
+ from transformer_model import TransformerModel
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Hyperparameters
+ d_model = 512  # Dimension of the embeddings and the token representations
+ seq_length = 10  # Length of the input and output sequences
+ vocab_size = 25672  # Size of the vocabulary
+ batch_size = 32  # Batch size for training
+ num_heads = 8  # Number of heads in multi-head attention
+ dim_feedforward = 2048  # Dimension of the feedforward network in the encoder and decoder
+
+ # The TransformerModel class is imported from transformer_model.py
+ model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, seq_length)
+ model.load_state_dict(torch.load('transformer_model.pth', map_location=device))
+ model = model.to(device)
+ model.eval()  # Set the model to evaluation mode
+
+ # Load the vocabulary
+ with open('vocabulary.json', 'r') as vocab_file:
+     vocab = json.load(vocab_file)
+
+ if '<unk>' not in vocab:
+     # Assign the next integer index to <unk>
+     vocab['<unk>'] = len(vocab)
+
+
+ def text_to_tensor(text, vocab, seq_length):
+     tokens = text.split()
+     indices = [vocab.get(token, vocab['<unk>']) for token in tokens]  # Replace unknown tokens with <unk>
+     indices = indices[:seq_length]
+     indices += [vocab['<pad>']] * (seq_length - len(indices))
+     return torch.tensor(indices, dtype=torch.long, device=device).unsqueeze(0)  # Add batch dimension
+
+
+ input_text = "please make the"
+ input_tensor = text_to_tensor(input_text, vocab, seq_length)
+ src = input_tensor
+ tgt = input_tensor
+
+ # Function to generate a square subsequent (causal) mask
+ def generate_square_subsequent_mask(sz):
+     mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
+     return mask
+
+ # Function to create a padding mask
+ def create_padding_mask(seq):
+     mask = (seq == 0).transpose(0, 1)  # Assumes 0 is the padding index
+     return mask
+
+ src_seq_len = src.size(1)
+ tgt_seq_len = tgt.size(1)
+
+ src_mask = generate_square_subsequent_mask(src_seq_len)
+ tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
+ src_key_padding_mask = create_padding_mask(src)
+ tgt_key_padding_mask = create_padding_mask(tgt)
+
+ # Shape check for the tensors going into the model
+ print(src.size(), tgt.size(), src_mask.size(), tgt_mask.size(),
+       src_key_padding_mask.size(), tgt_key_padding_mask.size())
+
+ with torch.no_grad():
+     output = model(src, tgt, src_mask, tgt_mask,
+                    src_key_padding_mask.transpose(0, 1), tgt_key_padding_mask.transpose(0, 1))
+
+ predicted_indices = torch.argmax(output, dim=-1).squeeze(0).tolist()
+ print(predicted_indices)
+
+ inverse_vocab = {value: key for key, value in vocab.items()}
+
+ flattened_list = list(itertools.chain.from_iterable(predicted_indices))
+ print([inverse_vocab[key] for key in flattened_list])
+
+
+ def generate_prediction(text, model, vocab, seq_length):
+     model.eval()  # Make sure the model is in eval mode
+
+     # Convert text to tensor
+     input_tensor = text_to_tensor(text, vocab, seq_length)
+
+     # Build the masks the model expects (same shapes as the call above)
+     mask = generate_square_subsequent_mask(input_tensor.size(1))
+     padding_mask = create_padding_mask(input_tensor).transpose(0, 1)
+
+     # Generate prediction; for simplicity, the same tensor is used as src and tgt
+     with torch.no_grad():
+         output = model(input_tensor, input_tensor, mask, mask, padding_mask, padding_mask)
+
+     # Convert the output tensor to tokens
+     predicted_indices = torch.argmax(output, dim=-1).squeeze(1).tolist()
+     predicted_tokens = [inverse_vocab[index] for index in predicted_indices]
+
+     return predicted_tokens
+
+
+ # Example usage
+ text = """Here were the servants of your adversary And
+ yours"""
+
+ prediction = generate_prediction(text, model, vocab, seq_length)
+ print(prediction)
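A small helper sketch, assuming the inverse_vocab built above and padding index 0: it joins a list of predicted indices into a readable string, dropping any predicted <pad> tokens.

def indices_to_text(indices, inverse_vocab, pad_index=0):
    return " ".join(inverse_vocab[i] for i in indices if i != pad_index)

# e.g. indices_to_text(flattened_list, inverse_vocab)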
tiny-shakespeare.txt ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,83 @@
+ # # Hyperparameters
+ # d_model = 512  # Dimension of the embeddings and the token representations
+ # seq_length = 10  # Length of the input and output sequences
+ # vocab_size = 1000  # Size of the vocabulary
+ # batch_size = 32  # Batch size for training
+ # num_heads = 8  # Number of heads in multi-head attention
+ # dim_feedforward = 2048  # Dimension of the feedforward network in the encoder and decoder
+ # max_seq_length = seq_length  # Maximum sequence length
+
+ # # Initialize the model
+ # transformer_model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, max_seq_length)
+ transformer_model = transformer_model.to(device)
+
+ # Define the loss function and optimizer
+ loss_fn = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)
+
+ # Function to generate a square subsequent (causal) mask
+ def generate_square_subsequent_mask(sz):
+     mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
+     return mask
+
+ # Function to create a padding mask
+ def create_padding_mask(seq):
+     mask = (seq == 0).transpose(0, 1)  # Assumes 0 is the padding index
+     return mask
+
+ # Training loop
+ num_epochs = 10
+ for epoch in range(num_epochs):
+     transformer_model.train()
+     epoch_loss = 0
+
+     for batch_idx, batch in enumerate(train_data_loader):
+         src_batch = batch[:, :-1].to(device)
+         tgt_input = batch[:, :-1].to(device)  # Drops the final (extra) token
+         targets = batch[:, 1:].contiguous().view(-1).to(device)  # Shifted by one for teacher forcing
+
+         optimizer.zero_grad()
+
+         # Generate square subsequent masks and padding masks
+         src_seq_len = src_batch.size(0)
+         tgt_seq_len = tgt_input.size(0)
+         src_mask = generate_square_subsequent_mask(src_seq_len)
+         tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
+         src_padding_mask = create_padding_mask(src_batch)
+         tgt_padding_mask = create_padding_mask(tgt_input)
+
+         # Forward pass
+         output = transformer_model(
+             src_batch.transpose(0, 1), tgt_input.transpose(0, 1),
+             src_mask, tgt_mask,
+             src_padding_mask, tgt_padding_mask
+         )
+
+         # Shapes observed for a final partial batch of 9 examples:
+         # src_batch:        torch.Size([9, 10])
+         # tgt_input:        torch.Size([9, 10])
+         # src_mask:         torch.Size([9, 9])
+         # tgt_mask:         torch.Size([9, 9])
+         # src_padding_mask: torch.Size([10, 9])
+         # tgt_padding_mask: torch.Size([10, 9])
+
+         output = output.view(-1, vocab_size)
+         loss = loss_fn(output, targets)
+
+         loss.backward()
+         optimizer.step()
+
+         epoch_loss += loss.item()
+
+         if batch_idx % 100 == 0:  # Print the loss every 100 batches
+             print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {epoch_loss / (batch_idx+1):.4f}")
+
+     print(f"Epoch {epoch+1}/{num_epochs} finished, Average Loss: {epoch_loss / len(train_data_loader):.4f}")
+
+ # Save the model and the vocabulary
+ torch.save(transformer_model.state_dict(), 'transformer_model.pth')
+
+ import json
+
+ with open('vocabulary.json', 'w') as vocab_file:
+     json.dump(dataset.vocab, vocab_file)
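Because loss_fn is token-level cross-entropy (in nats), the average loss printed above can also be reported as perplexity; a one-line sketch, assuming it runs right after the training loop:

import math
print(f"Final epoch perplexity: {math.exp(epoch_loss / len(train_data_loader)):.2f}")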
transformer_model.py ADDED
@@ -0,0 +1,188 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Hyperparameters
+ d_model = 512  # Dimension of the embeddings and the token representations
+ # batch_size = 32  # Batch size for training
+ num_heads = 8  # Number of heads in multi-head attention
+ dim_feedforward = 2048  # Dimension of the feedforward network in the encoder and decoder
+
+ vocab_size = 25672
+ seq_length = 10  # Length of the input and output sequences
+
+ # To generate text, we need to define several components:
+ # 1. Token Embedding: converts token indices to vectors
+ # 2. Positional Encoding: gives the model information about the order of tokens
+ # 3. Transformer Block: an encoder plus a decoder
+ # 4. Output Layer: converts the transformer output to token probabilities
+
+ # Token Embedding
+ class TokenEmbedding(nn.Module):
+     def __init__(self, vocab_size, d_model):
+         super().__init__()
+         self.embedding = nn.Embedding(vocab_size, d_model)
+
+     def forward(self, tokens):
+         return self.embedding(tokens)
+
+ # Positional Encoding
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+         self.max_len = max_len
+         self.d_model = d_model
+         self.register_buffer('pe', self.generate_positional_encoding())
+
+     def generate_positional_encoding(self):
+         # Create a 'pe' buffer long enough for 'max_len' positions
+         pe = torch.zeros(self.max_len, self.d_model)
+         position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / self.d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: [max_len, 1, d_model]
+         return pe
+
+     def forward(self, x):
+         # x is expected sequence-first: [seq length, batch size, d_model].
+         # Slice the encoding to the sequence length of x; the size-1 batch
+         # dimension of pe broadcasts across the whole batch.
+         pe = self.pe[:x.size(0), :]
+         x = x + pe
+         return x
+
+
+ # Transformer Block
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model, num_heads, dim_feedforward, dropout=0.1):
+         super().__init__()
+         # Encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=num_heads,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout
+         )
+         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=10)
+
+         # Decoder
+         decoder_layer = nn.TransformerDecoderLayer(
+             d_model=d_model,
+             nhead=num_heads,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout
+         )
+         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=10)
+
+     def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
+         # Forward pass through the encoder and then the decoder
+         memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
+         output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=None,
+                               tgt_key_padding_mask=tgt_key_padding_mask,
+                               memory_key_padding_mask=src_key_padding_mask)
+         return output
+
+ # Output Layer
+ class OutputLayer(nn.Module):
+     def __init__(self, d_model, vocab_size):
+         super().__init__()
+         self.linear = nn.Linear(d_model, vocab_size)
+
+     def forward(self, x):
+         return F.log_softmax(self.linear(x), dim=-1)
+
+
+ # Putting it all together
+ class TransformerModel(nn.Module):
+     def __init__(self, vocab_size, d_model, num_heads, dim_feedforward, max_seq_length):
+         super().__init__()
+         self.d_model = d_model
+         self.embedding = TokenEmbedding(vocab_size, d_model)
+         self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
+         self.transformer_block = TransformerBlock(d_model, num_heads, dim_feedforward)
+         self.output_layer = OutputLayer(d_model, vocab_size)
+
+     def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
+         src = self.embedding(src) * math.sqrt(self.d_model)
+         tgt = self.embedding(tgt) * math.sqrt(self.d_model)
+         src = self.pos_encoder(src)
+         tgt = self.pos_encoder(tgt)
+         transformer_output = self.transformer_block(src.transpose(0, 1), tgt.transpose(0, 1),
+                                                     src_mask, tgt_mask,
+                                                     src_key_padding_mask, tgt_key_padding_mask)
+         out = self.output_layer(transformer_output)
+         return out
+
+ # # Create a sample source and target batch
+ # src = torch.randint(low=0, high=vocab_size, size=(seq_length, batch_size))
+ # tgt = torch.randint(low=0, high=vocab_size, size=(seq_length, batch_size))
+
+ # # Masks and padding
+ # src_mask = torch.zeros((seq_length, seq_length)).type(torch.bool)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
+
+ # src_key_padding_mask = torch.zeros(batch_size, seq_length).type(torch.bool)  # Assuming no padding
+ # tgt_key_padding_mask = torch.zeros(batch_size, seq_length).type(torch.bool)  # Assuming no padding
+
+ # Initialize the model
+ transformer_model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, seq_length)
+
+ # Forward pass
+ # output = transformer_model(src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask)
+
+ # Show the output size
+ # print(output.size())  # (seq_length, batch_size, vocab_size)
+
+ # In this code, we define a simple Transformer model for educational purposes.
+ # It can be expanded by increasing the number of layers, adding dropout, and including more complex masking.
+ # The actual training loop, loss calculation, and optimization steps are not shown here.
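A minimal shape smoke test, assuming the batch-first call convention used in predict.py, causal masks of shape (seq, seq), and all-False padding masks (no padding); it only checks that a forward pass runs and prints the output size.

demo_batch = 4  # hypothetical small batch for the check
src = torch.randint(0, vocab_size, (demo_batch, seq_length))
tgt = torch.randint(0, vocab_size, (demo_batch, seq_length))
causal_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
pad_mask = torch.zeros(demo_batch, seq_length, dtype=torch.bool)
out = transformer_model(src, tgt, causal_mask, causal_mask, pad_mask, pad_mask)
print(out.size())  # expected: (seq_length, demo_batch, vocab_size)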