Lwasinam committed on
Commit
22ca2be
1 Parent(s): feebcbc

Upload 5 files

Files changed (5)
  1. config.py +26 -0
  2. dataset.py +131 -0
  3. inference.py +155 -0
  4. model.py +411 -0
  5. train.py +374 -0
config.py ADDED
@@ -0,0 +1,26 @@
from pathlib import Path


def get_config():
    return {
        "batch_size": 1,
        "num_epochs": 20,
        "lr": 1e-5,
        "seq_len": 261,
        "d_model": 768,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": None,
        "tokenizer_file": "tokenizer.json",
        "experiment_name": "runs/tmodel",
        "project_name": "proj1",
    }


def get_weights_file_path(config, epoch: str):
    model_folder = config["model_folder"]
    model_basename = config["model_basename"]
    model_filename = f"{model_basename}{epoch}"  # e.g. "tmodel_04"
    return str(Path('.') / model_folder / model_filename)
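For reference, a minimal sketch of how these two helpers fit together; the epoch string "04" is only an illustrative value:

from config import get_config, get_weights_file_path

config = get_config()
print(config["seq_len"])                    # 261, the padded caption length
print(get_weights_file_path(config, "04"))  # e.g. "weights/tmodel_04"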
dataset.py ADDED
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, Dataset
from transformers import ViTFeatureExtractor
from transformers import ViTImageProcessor
from io import BytesIO
from base64 import b64decode
from PIL import Image, ImageFile
import base64
import itertools
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib
import random
import PIL.Image
from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


USER_AGENT = get_datasets_user_agent()

# Image feature extractor (the dataset below uses ViTImageProcessor instead)
model_id = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)


class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_tgt, seq_len):
        super().__init__()
        self.seq_len = seq_len
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        self.ds = ds
        self.tokenizer_tgt = tokenizer_tgt
        # self.tgt_lang = tgt_lang
        self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
        self.sos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        data_pair = self.ds[idx]

        src_image = data_pair['image_base64_str']
        tgt_text = data_pair['outputs']

        # Decode the base64 image and preprocess it for the ViT encoder
        src_image = Image.open(BytesIO(b64decode(''.join(src_image))))
        if src_image.mode != 'RGB':
            src_image = src_image.convert('RGB')
        src_image = self.processor(src_image, return_tensors='pt')

        # Transform the caption into token ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text)

        # <s> goes only on the decoder input, </s> only on the label
        dec_input_tokens = dec_input_tokens[:self.seq_len - 1]
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add only the <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only the </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            'encoder_input': src_image['pixel_values'].squeeze(0),  # (3, 224, 224)
            'decoder_input': decoder_input,  # (seq_len)
            # encoder mask not used
            "encoder_mask": torch.cat((torch.ones(197), torch.zeros(63))).unsqueeze(0).unsqueeze(0),  # (1, 1, 260)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "tgt_text": tgt_text,
        }


def causal_mask(size):
    # True where a position may attend: itself and earlier positions only
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
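As a quick illustration of the masking logic above, here is a self-contained sketch with a toy sequence length of 5 and pad id 0 (the real dataset uses GPT-2 token ids; causal_mask is repeated so the snippet stands alone):

import torch

def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

pad_id = 0
decoder_input = torch.tensor([7, 3, 9, pad_id, pad_id])          # [SOS], two tokens, then padding
padding_mask = (decoder_input != pad_id).unsqueeze(0).int()      # (1, 5)
decoder_mask = padding_mask & causal_mask(decoder_input.size(0)) # (1, 5, 5)

print(decoder_mask.int())
# Entry (i, j) is 1 only when j <= i and token j is not padding,
# so each position attends to itself and earlier real tokens, never
# to padding or to the future.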
inference.py ADDED
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import transformers
from torch.utils.data import Dataset
from transformers import ViTFeatureExtractor
from io import BytesIO
from base64 import b64decode
from PIL import Image
from accelerate import Accelerator
import base64
from config import get_config
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from model import build_transformer
import torch.nn.functional as F
from transformers import GPT2TokenizerFast


def process(model, image, tokenizer, device):
    image = get_image(image)
    model.eval()
    with torch.no_grad():
        encoder_input = image.unsqueeze(0).to(device)  # (1, 3, 224, 224)
        model_out = greedy_decode(model, encoder_input, None, tokenizer, 196, device)
        model_text = tokenizer.decode(model_out.detach().cpu().numpy())
        print(model_text)


# get image prompt
def get_image(image):
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

    image = Image.open(BytesIO(b64decode(''.join(image))))

    if image.mode != 'RGB':
        image = image.convert('RGB')

    enc_input = feature_extractor(image, return_tensors='pt')

    return enc_input['pixel_values'].squeeze(0)  # (3, 224, 224)


# get tokenizer
def get_or_build_tokenizer(config):
    tokenizer = GPT2TokenizerFast.from_pretrained(
        "openai-community/gpt2",
        unk_token='[UNK]', bos_token='[SOS]', eos_token='[EOS]', pad_token='[PAD]'
    )
    return tokenizer


def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


# get model
def get_model(config, vocab_tgt_len):
    model = build_transformer(vocab_tgt_len, config['seq_len'], d_model=config['d_model'])
    return model


# greedy decode
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, None)

    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).long().to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build the causal mask for the target
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)

        # Calculate the decoder output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Next-token probabilities
        logits = model.project(out[:, -1])
        probabilities = F.softmax(logits, dim=-1)

        # Greedily select the next word
        next_word = torch.argmax(probabilities, dim=1)

        # Append the next word to the decoder input
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)


def image_base64():
    with open('C:/AI/projects/vision_model_pretrained/validation/content/memory_image_23330.jpg', 'rb') as image_file:
        base64_bytes = base64.b64encode(image_file.read())

    base64_string = base64_bytes.decode()
    return base64_string


def start():
    print('start')
    accelerator = Accelerator()
    device = accelerator.device

    config = get_config()
    tokenizer = get_or_build_tokenizer(config)
    model = get_model(config, len(tokenizer))
    model = accelerator.prepare(model)
    accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')

    image = image_base64()

    process(model, image, tokenizer, device)


if __name__ == '__main__':
    start()
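The paths in image_base64 and start are specific to the author's machine. A minimal sketch of calling the same entry points with an arbitrary local image and checkpoint directory; caption_image, image.jpg and the checkpoint path are illustrative placeholders, and the checkpoint directory is assumed to be one written by accelerator.save_state:

import base64
from accelerate import Accelerator
from config import get_config
from inference import get_or_build_tokenizer, get_model, process

def caption_image(image_path: str, checkpoint_dir: str):
    accelerator = Accelerator()
    device = accelerator.device

    config = get_config()
    tokenizer = get_or_build_tokenizer(config)
    model = accelerator.prepare(get_model(config, len(tokenizer)))
    accelerator.load_state(checkpoint_dir)

    # Encode the image file as base64, as process() expects
    with open(image_path, 'rb') as f:
        image_b64 = base64.b64encode(f.read()).decode()

    process(model, image_b64, tokenizer, device)

# caption_image('image.jpg', 'weights/tmodel_00')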
model.py ADDED
@@ -0,0 +1,411 @@
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
import time
from matplotlib.image import imread
from transformers import ViTFeatureExtractor
from io import BytesIO
from base64 import b64decode
import base64
from transformers import ViTImageProcessor, ViTModel

## code from @jankrepl on github


class PretrainedVit():
    def __init__(self):
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    def forward(self, x):
        self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.config.output_hidden_states = True
        outputs = self.model(x)
        # hidden_states: tuple of 13 tensors (embedding output + 12 encoder layers)
        last_hidden_states = outputs.hidden_states
        return list(last_hidden_states)


class PatchEmbed(nn.Module):
    """Split image into patches and then embed them.

    Parameters
    ----------
    img_size : int
        Size of the image (it is a square).

    patch_size : int
        Size of the patch (it is a square).

    in_chans : int
        Number of input channels.

    embed_dim : int
        The embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding.
    """
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=1024, num_registers=6):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.norm = RMSNorm()
        self.n_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.n_patches + 1 + num_registers, embed_dim)
        )
        # Adding CLS token as a learnable parameter
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_token = nn.Parameter(torch.zeros(num_registers, embed_dim))

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1 + num_registers, embed_dim)`.
        """
        x = self.proj(x)       # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2)       # (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embed_dim)
        batch_size = x.shape[0]

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # Expand CLS tokens for the batch
        x = torch.cat([cls_tokens, x], dim=1)

        # (n_samples, n_patches + 1 + num_registers, embed_dim): add register tokens
        register_tokens = self.register_token.unsqueeze(0).expand(batch_size, -1, -1)
        x = torch.cat([x, register_tokens], dim=1)
        x = self.norm(x)
        x = x + self.pos_embed  # Learnable positional embedding

        return x


## not used
class RMSNorm(nn.Module):
    def __init__(self, dim: int = 1024, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(self.dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


class LayerNormalization(nn.Module):

    def __init__(self, eps: float = 1e-12) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(1))  # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        std = x.std(dim=-1, keepdim=True)    # (batch, seq_len, 1)
        # eps prevents division by zero when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))


class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # Create a vector of shape (d_model / 2)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)


class ResidualConnection(nn.Module):

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h              # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h  # Dimension of the vector seen by each head
        self.w_q = nn.Linear(d_model, d_model)  # Wq
        self.w_k = nn.Linear(d_model, d_model)  # Wk
        self.w_v = nn.Linear(d_model, d_model)  # Wv
        self.w_o = nn.Linear(d_model, d_model)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # Also return the attention scores, which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask, is_cross=False):
        query = self.w_q(q)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)    # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # if is_cross:
        #     attention_viz(self.attention_scores)  # attention_viz is not defined in this file

        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)


class EncoderBlock(nn.Module):

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float, layer: int) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
        self.layer = layer

    def forward(self, x, src_mask, index):
        # x is the list of ViT hidden states; keep only the last one
        out = x[11]
        # out = self.residual_connections[1](out, self.feed_forward_block)
        return out


class Encoder(nn.Module):

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        # Only the first block runs; it selects the final ViT hidden state
        for index, layer in enumerate(self.layers):
            x = layer(x, mask, index)
            break
        return self.norm(x)


class DecoderBlock(nn.Module):

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)

        return x


class Decoder(nn.Module):

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)


class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, tgt_embed: InputEmbeddings, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer, att: PretrainedVit) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        self.patch_embed = PatchEmbed(img_size=224, patch_size=14)
        self.att = att

    def encode(self, src, src_mask):
        # (batch, 197, d_model)
        attention_list = self.att.forward(src)
        # Drop the embedding output and pass the 12 ViT encoder hidden states
        return self.encoder(attention_list[1:], src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)


def build_transformer(tgt_vocab_size: int, tgt_seq_len: int, d_model: int = 768, N: int = 10, h: int = 12, dropout: float = 0.1, d_ff: int = 3072) -> Transformer:
    # Create the embedding layer
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layer
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Attention from the pretrained ViT
    att = PretrainedVit()

    # Create the encoder blocks
    encoder_blocks = []
    for i in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout, i)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, tgt_embed, tgt_pos, projection_layer, att)

    # Initialize the parameters. PretrainedVit is not an nn.Module, so the ViT
    # weights are not in transformer.parameters() and are not re-initialized.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
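A small shape-check sketch for the pieces above. It instantiates the full model, so it downloads the pretrained ViT and builds the default 10 decoder blocks; the vocabulary size 50261 is an assumption matching the GPT-2 tokenizer plus the four added special tokens, but any positive value works for the check:

import torch
from model import build_transformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_transformer(tgt_vocab_size=50261, tgt_seq_len=261, d_model=768).to(device).eval()

image = torch.randn(1, 3, 224, 224, device=device)           # one preprocessed image
tokens = torch.randint(0, 50261, (1, 261), device=device)    # one padded caption

with torch.no_grad():
    enc = model.encode(image, None)              # (1, 197, 768): final ViT hidden state
    dec = model.decode(enc, None, tokens, None)  # (1, 261, 768)
    out = model.project(dec)                     # (1, 261, 50261): log-probabilities

print(enc.shape, dec.shape, out.shape)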
train.py ADDED
@@ -0,0 +1,374 @@
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path

import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
import wandb
import accelerate
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import load_model, save_model
from accelerate import Accelerator
from transformers import GPT2TokenizerFast
import threading


def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.module.encode(source, None)

    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).long().to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build the causal mask for the target
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)

        # Calculate the decoder output
        out = model.module.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Next-token probabilities
        logits = model.module.project(out[:, -1])
        probabilities = F.softmax(logits, dim=-1)

        # Greedily select the next word
        next_word = torch.argmax(probabilities, dim=1)

        # Append the next word to the decoder input
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=3):
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # Get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, 3, 224, 224)
            encoder_mask = batch["encoder_mask"].to(device)    # (b, 1, 1, seq_len)

            # Check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, None, tokenizer_tgt, max_len, device)

            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the target and the model output
            print_msg('-' * console_width)
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    # if writer:
    #     # Compute the char error rate
    #     metric = torchmetrics.CharErrorRate()
    #     cer = metric(predicted, expected)
    #     writer.add_scalar('validation cer', cer, global_step)
    #     writer.flush()

    #     # Compute the word error rate
    #     metric = torchmetrics.WordErrorRate()
    #     wer = metric(predicted, expected)
    #     writer.add_scalar('validation wer', wer, global_step)
    #     writer.flush()

    #     # Compute the BLEU metric
    #     metric = torchmetrics.BLEUScore()
    #     bleu = metric(predicted, expected)
    #     writer.add_scalar('validation BLEU', bleu, global_step)
    #     writer.flush()


def get_all_sentences(ds):
    for item in ds:
        yield item['text']


def batch_iterator(data):
    for i in range(0, len(data)):
        yield data[i]['text']


# Assuming batch_iterator is a function that yields batches
def tqdm_batch_iterator(data, *args, **kwargs):
    for batch in tqdm(batch_iterator(data, *args, **kwargs), total=len(data)):
        yield batch


def get_or_build_tokenizer(config, ds):
    tokenizer = GPT2TokenizerFast.from_pretrained(
        "openai-community/gpt2",
        unk_token='[UNK]', bos_token='[SOS]', eos_token='[EOS]', pad_token='[PAD]'
    )
    return tokenizer
    # tokenizer_path = Path(config['tokenizer_file'])
    # if not Path.exists(tokenizer_path):
    #     # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    #     tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    #     tokenizer.pre_tokenizer = Whitespace()
    #     trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    #     tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
    #     tokenizer.save(str(tokenizer_path))
    # else:
    #     tokenizer = Tokenizer.from_file(str(tokenizer_path))
    # return tokenizer


def get_ds(config):
    # Use the M3IT COCO subset: full train split, plus a small slice of validation
    train_ds_raw = load_dataset("MMInstruction/M3IT", 'coco', split='train')
    val_ds_raw = load_dataset("MMInstruction/M3IT", 'coco', split='validation[:2%]')

    # Build the tokenizer
    tokenizer_tgt = get_or_build_tokenizer(config, train_ds_raw)

    seed = 20
    torch.manual_seed(seed)

    train_ds = BilingualDataset(train_ds_raw, tokenizer_tgt, config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_tgt, config['seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_tgt


def get_model(config, vocab_tgt_len):
    model = build_transformer(vocab_tgt_len, config['seq_len'], d_model=config['d_model'])
    return model


def train_model(config):

    accelerator = Accelerator()

    wandb.login(key='c20a1022142595d7d1324fdc53b3ccb34c0ded22')
    wandb.init(project="Vision", name=config['project_name'])

    # Initialize WandB configuration
    wandb.config.epochs = config['num_epochs']
    wandb.config.batch_size = config['batch_size']
    wandb.config.learning_rate = config['lr']

    # Define the device
    device = accelerator.device
    print("Using device:", device)

    # Make sure the weights folder exists
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_tgt = get_ds(config)
    model = get_model(config, len(tokenizer_tgt)).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=(0.9, 0.98), eps=1e-9)

    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0

    def save_models():
        accelerator.save_state(output_dir='/kaggle/working/weights/tmodel_00')
        print(f'saving global step {global_step}')

    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        accelerator.load_state(model_filename)
        initial_epoch = 4

    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.convert_tokens_to_ids('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):

        # timer = threading.Timer(5 * 60, save_models)
        # timer.start()

        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

        for batch in batch_iterator:

            encoder_input = batch["encoder_input"].to(device)  # (B, 3, 224, 224)
            decoder_input = batch["decoder_input"].to(device)  # (B, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)    # (B, 1, 1, seq_len)
            decoder_mask = batch["decoder_mask"].to(device)    # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.module.encode(encoder_input, None)                            # (B, 197, d_model)
            decoder_output = model.module.decode(encoder_output, None, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.module.project(decoder_output)                                   # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch["label"].to(device)  # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, len(tokenizer_tgt)), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            wandb.log({"Training Loss": loss.item(), "Global Step": global_step})

            # Backpropagate the loss
            accelerator.backward(loss)

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
            if global_step in (1000, 5000, 10000, 15000, 20000, 30000):
                run_validation(model, val_dataloader, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
                model.train()

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        accelerator.save_state(output_dir=f'/kaggle/working/weights/tmodel_{epoch:02d}')

        # Compute the validation loss at the end of every epoch
        model.eval()
        eval_loss = 0.0

        with torch.no_grad():
            batch_itere = tqdm(val_dataloader, desc="Processing loss")
            for batch in batch_itere:

                encoder_input = batch['encoder_input'].to(device)  # (B, 3, 224, 224)
                decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device)    # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device)    # (B, 1, seq_len, seq_len)

                # Run the tensors through the encoder, decoder and the projection layer
                encoder_output = model.module.encode(encoder_input, None)
                decoder_output = model.module.decode(encoder_output, None, decoder_input, decoder_mask)
                proj_output = model.module.project(decoder_output)  # (B, seq_len, vocab_size)

                # Gather outputs and labels across processes before computing the metric
                proj_output, label = accelerator.gather_for_metrics((proj_output, batch["label"]))

                # Compute the loss using a simple cross entropy
                ls = loss_fn(proj_output.view(-1, len(tokenizer_tgt)), label.view(-1))
                batch_itere.set_postfix({"loss": f"{ls.item():6.3f}"})
                eval_loss += ls

        avg_val_loss = eval_loss / len(val_dataloader)
        accelerator.print(f"Epoch {epoch}, Validation Loss: {avg_val_loss}")
        wandb.log({"Validation Loss": avg_val_loss.item(), "Global Step": global_step})

        run_validation(model, val_dataloader, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)
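Note that train_model calls model.module, which only exists when accelerator.prepare wraps the model in DistributedDataParallel, i.e. when the script runs with more than one process (for example via accelerate launch --num_processes 2 train.py). A hedged sketch of the equivalent programmatic launch using accelerate's notebook_launcher; num_processes=2 and the _worker helper are assumptions, not part of the uploaded code:

from accelerate import notebook_launcher

from config import get_config
from train import train_model

def _worker():
    train_model(get_config())

# Spawn two processes so accelerator.prepare wraps the model in DDP
# and the model.module calls inside train_model resolve correctly.
notebook_launcher(_worker, args=(), num_processes=2)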