pt-sk committed
Commit 86b15cd
1 Parent(s): f5265c2

Upload 4 files

Files changed (4)
  1. config.py +33 -0
  2. dataset.py +90 -0
  3. model.py +267 -0
  4. train.py +274 -0
config.py ADDED
@@ -0,0 +1,33 @@
from pathlib import Path

def get_config():
    return {
        "batch_size": 4,
        "num_epochs": 10,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

def get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])
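For reference, a minimal sketch of how these helpers are meant to be combined (the epoch string "09" and the printed paths are illustrative only; they depend on what has actually been saved under opus_books_weights):

from config import get_config, get_weights_file_path, latest_weights_file_path

config = get_config()
# Path for a specific epoch checkpoint, e.g. ./opus_books_weights/tmodel_09.pt
print(get_weights_file_path(config, "09"))
# Most recent checkpoint on disk, or None if nothing has been saved yet
print(latest_weights_file_path(config))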
dataset.py ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # We will add <s> and </s>
        # The decoder input gets only <s>; </s> is added only to the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> tokens
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
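A quick illustration of what causal_mask returns, since it is the building block of the decoder_mask above (a hedged sketch; the size 4 is an arbitrary toy value):

from dataset import causal_mask

print(causal_mask(4))
# tensor([[[ True, False, False, False],
#          [ True,  True, False, False],
#          [ True,  True,  True, False],
#          [ True,  True,  True,  True]]])
# Position i may only attend to positions <= i.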
model.py ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features))  # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # eps is to prevent dividing by zero when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # Create a vector of shape (d_model / 2)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model / 2)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)  # sin(position / (10000 ** (2i / d_model)))
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)  # cos(position / (10000 ** (2i / d_model)))
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)

class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h  # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h  # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # Also return the attention scores, which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)

class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
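A hedged smoke test of build_transformer with toy sizes (the vocabulary sizes, sequence length and batch size below are made up for illustration; the training script uses the values from config.py instead):

import torch
from model import build_transformer
from dataset import causal_mask

model = build_transformer(src_vocab_size=1000, tgt_vocab_size=1000, src_seq_len=10, tgt_seq_len=10)
src = torch.randint(0, 1000, (2, 10))             # (batch, seq_len) of token ids
tgt = torch.randint(0, 1000, (2, 10))
src_mask = torch.ones(2, 1, 1, 10).int()          # pretend there is no padding
tgt_mask = causal_mask(10)                        # (1, seq_len, seq_len)
enc = model.encode(src, src_mask)                 # (2, 10, 512)
dec = model.decode(enc, src_mask, tgt, tgt_mask)  # (2, 10, 512)
print(model.project(dec).shape)                   # torch.Size([2, 10, 1000])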
train.py ADDED
@@ -0,0 +1,274 @@
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path, latest_weights_file_path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build mask for target
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # Calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # Get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # Check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    if writer:
        # Compute the character error rate
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        # Compute the word error rate
        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # Compute the BLEU metric
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, expected)
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()

def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    # The dataset only has a train split, so we divide it ourselves
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentences
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

def train_model(config):
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)
    if device == 'cuda':
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device == 'mps':
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device)  # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device)  # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)
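Training is normally launched by running the script directly (python train.py). A hedged sketch of driving it from another script or notebook with a couple of overridden config values (the override values below are illustrative, not recommendations):

from config import get_config
from train import train_model

config = get_config()
config["batch_size"] = 8   # illustrative override
config["num_epochs"] = 2   # illustrative override
train_model(config)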