mojtaba-nafez commited on
Commit
2fa2727
β€’
1 Parent(s): 85615cd

add initial files to deploy

Browse files
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import PoemTextModel
2
+ from inference import predict_poems_from_text
3
+ from utils import get_poem_embeddings
4
+ import config as CFG
5
+ import json
6
+ import gradio as gr
7
+
8
+ def greet_user(name):
9
+ return "Hello " + name + " Welcome to Gradio!😎"
10
+
11
+ if __name__ == "__main__":
12
+ model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
13
+ model.eval()
14
+ # Inference: Output some example predictions and write them in a file
15
+ with open(CFG.dataset_path, encoding="utf-8") as f:
16
+ dataset = json.load(f)
17
+
18
+ def gradio_make_predictions(text):
19
+ beyts = predict_poems_from_text(model, poem_embeddings, text, [data['beyt'] for data in dataset], n=10)
20
+ return "\n".join(beyts)
21
+
22
+ CFG.batch_size = 512
23
+ model, poem_embeddings = get_poem_embeddings(dataset, model)
24
+ # print(poem_embeddings[0])
25
+ # with open('poem_embeddings.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
26
+ # f.write(json.dumps(poem_embeddings, indent= 4))
27
+
28
+ text_input = gr.Textbox(label = "Enter the text to find poem beyts for")
29
+ output = gr.Textbox()
30
+
31
+ app = gr.Interface(fn = gradio_make_predictions, inputs=text_input, outputs=output)
32
+ app.launch()
config.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
3
+ from transformers import BertTokenizer, BertModel, BertConfig, BertTokenizerFast
4
+ from transformers import XLMRobertaModel, XLMRobertaConfig
5
+ import os
6
+
7
+ """
8
+ Configurations
9
+ """
10
+ file_dirname = os.path.dirname(__file__) #in case it is needed for relative paths
11
+ dataset_path = os.path.join(file_dirname, "../data/Dataset-Merged.json") # dataset path for PoemTextModel training, validation and test
12
+ image_path = "" # path to append to the image filenames of datasets used for CLIPModel training
13
+ random_seed = 3 # the seed used to shuffle dataset with
14
+
15
+ # what percentage of dataset will be used for each set?
16
+ train_propotion = 0.85
17
+ val_propotion = 0.05
18
+ # The remaining will be used as the test set
19
+
20
+ batch_size = 128
21
+ num_workers = 0 # parameter of torch Dataloader
22
+ lr = 1e-3 # learning rate
23
+ weight_decay = 1e-3
24
+ patience = 2 # patience parameter for lr scheduler
25
+ factor = 0.5 # factor parameter for lr scheduler
26
+ epochs = 60
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ # Pretrained hugging face models chosen by poem_encoder_model
30
+ poem_encoder_dict = {
31
+ "Bert":{
32
+ "poem_encoder_pretrained_name": 'mitra-mir/BERT-Persian-Poetry',
33
+ },
34
+ "ALBERT":{
35
+ "poem_encoder_pretrained_name": 'mitra-mir/ALBERT-Persian-Poetry',
36
+ },
37
+ "ParsBERT":{
38
+ "poem_encoder_pretrained_name": 'HooshvareLab/bert-base-parsbert-uncased',
39
+ },
40
+ }
41
+
42
+ poem_encoder_model = "ParsBERT" ### Important! The base model for poem encoder (one of "Bert", "ALBERT" and "ParsBERT")
43
+ # keep this an empty string if you want to use the pretrained weights from
44
+ # huggingface (poem_encoder_dict[poem_encoder_model])/a fresh model.
45
+ # else give the path to encoder
46
+ poem_encoder_load_path = ""
47
+ # path to save encoder to
48
+ poem_encoder_save_path = "{}-poem-encoder".format(poem_encoder_model)
49
+
50
+ if poem_encoder_load_path:
51
+ poem_encoder_pretrained_name = poem_encoder_load_path
52
+ poem_tokenizer = poem_encoder_load_path
53
+ else:
54
+ poem_encoder_pretrained_name = poem_encoder_dict[poem_encoder_model]['poem_encoder_pretrained_name']
55
+ poem_tokenizer = poem_encoder_dict[poem_encoder_model]['poem_encoder_pretrained_name']
56
+
57
+ poem_embedding = 768 # embedding dim of poem encoder's output (for one token)
58
+ poems_max_length = 64 # max_length parameter when padding/truncating poems using poem tokenizer
59
+ # keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
60
+ poem_projection_load_path = os.path.join(file_dirname, "projections/{}_best_poem_projection.pt".format(poem_encoder_model))
61
+ # path to save projection to
62
+ poem_projection_save_path = "{}_best_poem_projection.pt".format(poem_encoder_model)
63
+ poem_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
64
+
65
+ # Pretrained hugging face models chosen by text_encoder_model
66
+ text_encoder_dict = {
67
+ "M-Bert":{
68
+ "text_encoder_pretrained_name": 'bert-base-multilingual-cased',
69
+ },
70
+ "XLM-RoBERTa":{
71
+ "text_encoder_pretrained_name": 'xlm-roberta-base',
72
+ },
73
+ "LaBSE":{
74
+ "text_encoder_pretrained_name": 'setu4993/LaBSE',
75
+ }
76
+ }
77
+ text_encoder_model = 'LaBSE' ### Important! The base model for text encoder (one of "M-Bert", "XLM-RoBERTa" and "LaBSE")
78
+ # keep this an empty string if you want to use the pretrained weights from huggingface/a fresh model. else give the path to encoder
79
+ text_encoder_load_path = ""
80
+ # path to save encoder to
81
+ text_encoder_save_path = "{}-text-encoder".format(text_encoder_model)
82
+
83
+ if text_encoder_load_path:
84
+ text_encoder_pretrained_name = text_encoder_load_path
85
+ text_tokenizer = text_encoder_load_path
86
+ else:
87
+ text_encoder_pretrained_name = text_encoder_dict[text_encoder_model]["text_encoder_pretrained_name"]
88
+ text_tokenizer = text_encoder_dict[text_encoder_model]["text_encoder_pretrained_name"]
89
+
90
+ text_embedding = 768 # embedding dim of text encoder's output (for one token)
91
+ text_max_length = 200 # max_length parameter when padding/truncating text using text tokenizer
92
+ # keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
93
+ text_projection_load_path = os.path.join(file_dirname, "projections/{}_best_text_projection.pt".format(text_encoder_model))
94
+ # path to save peojection to
95
+ text_projection_save_path = "{}_best_text_projection.pt".format(text_encoder_model)
96
+ text_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
97
+
98
+
99
+ image_encoder_model = 'resnet50' # image model name to load via timm library
100
+ # keep this an empty string if you want to use the pretrained weights from huggingface/a fresh model. else give the path to encoder
101
+ image_encoder_weights_load_path = ""
102
+ # path to save encoder weights to
103
+ image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
104
+ image_embedding = 2048 # embedding dim of image encoder's output (for one token)
105
+ # keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
106
+ image_projection_load_path = ""
107
+ # path to save projection to
108
+ image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
109
+ image_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
110
+
111
+ # classes of Tokenizer, Model and Config to use for each text/poem encoder model
112
+ tokenizers = {"ALBERT": AutoTokenizer, "M-Bert": BertTokenizer, "XLM-RoBERTa": AutoTokenizer, "ParsBERT":AutoTokenizer, "Bert":AutoTokenizer, "LaBSE": BertTokenizerFast}
113
+ encoders = {"ALBERT": AutoModel, "M-Bert": BertModel, "XLM-RoBERTa":XLMRobertaModel, "ParsBERT": AutoModel, "Bert":AutoModel, "LaBSE": BertModel}
114
+ configs = {"ALBERT": AutoConfig, "M-Bert": BertConfig, "XLM-RoBERTa": XLMRobertaConfig, "ParsBERT": AutoConfig, "Bert":AutoConfig, "LaBSE": BertConfig}
115
+
116
+
117
+ temperature = 1.0 # temperature parameter for scaling dot similarities
118
+
119
+ # image size
120
+ size = 224
121
+
122
+ # for projection head; used for poem, text and image encoders
123
+ projection_dim = 1024 # projection embedding dim (output of models dim)
124
+ dropout = 0.1 # fraction of the output of fc layer in projection head to be zeroed.
data/Dataset-Merged.json ADDED
The diff for this file is too large to render. See raw diff
 
data/test_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
data/train_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
data/val_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
datasets.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import albumentations as A
5
+ import config as CFG
6
+
7
+
8
+ class PoemTextDataset(torch.utils.data.Dataset):
9
+ """
10
+ torch Dataset for PoemTextModel.
11
+ ...
12
+ Attributes:
13
+ -----------
14
+ dataset_dict : list of dict
15
+ dataset containing poem-text pair with ids
16
+ encoded_poems : dict
17
+ output of tokenizer for beyts found in dataset_dict. max_length spedified in configs.
18
+ padding and truncation set to True to be truncated or padded to max length.
19
+ encoded_texts : dict
20
+ output of tokenizer for texts found in dataset_dict. max_length spedified in configs.
21
+ padding and truncation set to True to be truncated or padded to max length.
22
+
23
+ Methods:
24
+ --------
25
+ __get_item__(idx)
26
+ returns item with index idx.
27
+ __len__()
28
+ represents length of dataset
29
+ """
30
+ def __init__(self, dataset_dict):
31
+ """
32
+ Init class, save dataset_dict and calculate output of tokenizers for each text and poem using their corresponding tokenizers.
33
+ The tokenizers are chosen based on configs.
34
+
35
+ Parameters:
36
+ -----------
37
+ dataset_dict: list of dict
38
+ a list containing dictionaries which have "beyt", "text" and "id" keys.
39
+ """
40
+ self.dataset_dict = dataset_dict
41
+ poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
42
+ text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
43
+ self.encoded_poems = poem_tokenizer(
44
+ [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
45
+ )
46
+ self.encoded_texts = text_tokenizer(
47
+ [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
48
+ )
49
+
50
+ def __getitem__(self, idx):
51
+ """
52
+ returns a dict having data with index idx. the dict is used as an input to the PoemTextModel.
53
+
54
+ Parameters:
55
+ -----------
56
+ idx: int
57
+ index of the data to get
58
+
59
+ Returns:
60
+ --------
61
+ item: dict
62
+ a dict having tokenizers' output for poem and text, and id of the data with index idx
63
+ """
64
+ item = {}
65
+ item["beyt"] = {
66
+ key: torch.tensor(values[idx])
67
+ for key, values in self.encoded_poems.items()
68
+ }
69
+
70
+ item["text"] = {
71
+ key: torch.tensor(values[idx])
72
+ for key, values in self.encoded_texts.items()
73
+ }
74
+ item['id'] = self.dataset_dict[idx]['id']
75
+
76
+ return item
77
+
78
+
79
+ def __len__(self):
80
+ """
81
+ returns the length of the dataset
82
+
83
+ Returns:
84
+ --------
85
+ length: int
86
+ length using the length of dataset_dict we saved in class
87
+ """
88
+ return len(self.dataset_dict)
89
+
90
+
91
+ class CLIPDataset(torch.utils.data.Dataset):
92
+ """
93
+ torch Dataset for CLIPModel.
94
+ ...
95
+ Attributes:
96
+ -----------
97
+ dataset_dict : list of dict
98
+ dataset containing poem-image or text-image pair with ids
99
+ encoded : dict
100
+ output of tokenizer for beyts/texts found in dataset_dict. max_length spedified in configs.
101
+ padding and truncation set to True to be truncated or padded to max length.
102
+ transforms: albumentations.BasicTransform
103
+ transforms to apply to the images
104
+
105
+ Methods:
106
+ --------
107
+ __get_item__(idx)
108
+ returns item with index idx.
109
+ __len__()
110
+ represents length of dataset
111
+ """
112
+ def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
113
+ """
114
+ Init class, save dataset_dict and transforms and calculate output of tokenizers for each text and poem using their corresponding tokenizers.
115
+ The tokenizers are chosen based on configs.
116
+
117
+ Parameters:
118
+ -----------
119
+ dataset_dict: list of dict
120
+ a list containing dictionaries which have "beyt", "text" and "id" keys.
121
+ transforms: albumentations.BasicTransform
122
+ transforms to apply to the images
123
+ is_image_poem_pair: Bool, optional
124
+ if set to False, dataset has text-image pairs and must use the corresponding text tokenizer.
125
+ else has poem-images pairs and uses the poem tokenizer.
126
+ """
127
+ self.dataset_dict = dataset_dict
128
+ # using the poem tokenizer to encode poems or text tokenizer to encode text (based on configs).
129
+ if is_image_poem_pair:
130
+ poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
131
+ self.encoded = poem_tokenizer(
132
+ [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
133
+ )
134
+ else:
135
+ text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
136
+ self.encoded = text_tokenizer(
137
+ [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
138
+ )
139
+ self.transforms = transforms
140
+
141
+ def __getitem__(self, idx):
142
+ """
143
+ returns a dict having data with index idx. the dict is used as an input to the CLIPModel.
144
+
145
+ Parameters:
146
+ -----------
147
+ idx: int
148
+ index of the data to get
149
+
150
+ Returns:
151
+ --------
152
+ item: dict
153
+ a dict having tokenizers' output for poem and text, and id of the data with index idx
154
+ """
155
+ item = {}
156
+ # getting text from encoded texts
157
+ item["text"] = {
158
+ key: torch.tensor(values[idx])
159
+ for key, values in self.encoded.items()
160
+ }
161
+
162
+ # opening the image
163
+ image = cv2.imread(f"{CFG.image_path}{self.dataset_dict[idx]['image']}")
164
+ # converting BGR to RGB for transforms
165
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
166
+ # apply transforms
167
+ image = self.transforms(image=image)['image']
168
+ # permute dims of image
169
+ item['image'] = torch.tensor(image).permute(2, 0, 1).float()
170
+
171
+ return item
172
+
173
+
174
+ def __len__(self):
175
+ """
176
+ returns the length of the dataset
177
+
178
+ Returns:
179
+ --------
180
+ length: int
181
+ length using the length of dataset_dict we saved in class
182
+ """
183
+ return len(self.dataset_dict)
184
+
185
+
186
+
187
+ def get_transforms(mode="train"):
188
+ """
189
+ returns transforms to use on image based on mode
190
+
191
+ Parameters:
192
+ -----------
193
+ mode: str, optional
194
+ to distinguish between train and val/test transforms (here they are the same!)
195
+
196
+ Returns:
197
+ --------
198
+ item: dict
199
+ a dict having tokenizers' output for poem and text, and id of the data with index idx
200
+ """
201
+ if mode == "train":
202
+ return A.Compose(
203
+ [
204
+ A.Resize(CFG.size, CFG.size, always_apply=True), # resizing image to CFG.size
205
+ A.Normalize(max_pixel_value=255.0, always_apply=True), # normalizing image values
206
+ ]
207
+ )
208
+ else:
209
+ return A.Compose(
210
+ [
211
+ A.Resize(CFG.size, CFG.size, always_apply=True), # resizing image to CFG.size
212
+ A.Normalize(max_pixel_value=255.0, always_apply=True), # normalizing image values
213
+ ]
214
+ )
inference.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import torch
3
+ import cv2
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import config as CFG
7
+ from datasets import get_transforms
8
+
9
+ #for running this script as main
10
+ from utils import get_datasets, build_loaders
11
+ from models import PoemTextModel
12
+ from utils import get_poem_embeddings
13
+ import json
14
+ import os
15
+
16
+
17
+ def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10):
18
+ """
19
+ Returns n poems which are the most similar to a text query
20
+
21
+ Parameters:
22
+ -----------
23
+ model: PoemTextModel
24
+ model to compute text query's embeddings
25
+ poem_embeddings: sequence with shape (#poems, CFG.projection_dim)
26
+ poem embeddings to check similarity
27
+ query: str
28
+ text query
29
+ poems: list of str
30
+ poems corresponding to poem_embeddings
31
+ text_tokenizer: huggingface Tokenizer, optional
32
+ tokenizer to tokenize query with. if none, will instantiate a new text tokenizer using configs.
33
+ n: int, optional
34
+ number of poems to return
35
+
36
+ Returns:
37
+ --------
38
+ A list of n poem strings whose embeddings are the most similar to query text's embedding.
39
+
40
+ """
41
+ #Tokenizing and Encoding the query text
42
+ if not text_tokenizer:
43
+ text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
44
+
45
+ encoded_query = text_tokenizer([query])
46
+ batch = {
47
+ key: torch.tensor(values).to(CFG.device)
48
+ for key, values in encoded_query.items()
49
+ }
50
+
51
+ # getting query text's embeddings
52
+ model.eval()
53
+ with torch.no_grad():
54
+ text_features = model.text_encoder(
55
+ input_ids= batch["input_ids"], attention_mask=batch["attention_mask"]
56
+ )
57
+ text_embeddings = model.text_projection(text_features)
58
+
59
+ # normalizing and computing dot similarity of poem and text embeddings
60
+ poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
61
+ text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
62
+
63
+ dot_similarity = text_embeddings_n @ poem_embeddings_n.T
64
+
65
+ # returning top n poems based on embedding similarity
66
+ _, indices = torch.topk(dot_similarity.squeeze(0), n)
67
+ return [poems[idx] for idx in indices]
68
+
69
+
70
+ def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10):
71
+ """
72
+ Returns n poems which are the most similar to an image query
73
+
74
+ Parameters:
75
+ -----------
76
+ model: CLIPModel
77
+ model to compute image query's embeddings
78
+ poem_embeddings: sequence with shape (#poems, CFG.projection_dim)
79
+ poem embeddings to check similarity
80
+ image_filename: str
81
+ path and file name for the image query
82
+ poems: list of str
83
+ poems corresponding to poem_embeddings
84
+ n: int, optional
85
+ number of poems to return
86
+
87
+ Returns:
88
+ --------
89
+ A list of n poem strings whose embeddings are the most similar to image query's embedding.
90
+
91
+ """
92
+ # Reading, Processing and applying transforms to image (all explained in datasets.py)
93
+ image = cv2.imread(f"{image_filename}")
94
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
95
+ image = get_transforms(mode="test")(image=image)['image']
96
+ image = torch.tensor(image).permute(2, 0, 1).float()
97
+
98
+ # getting image query's embeddings
99
+ model.eval()
100
+ with torch.no_grad():
101
+ image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
102
+ image_embeddings = model.image_projection(image_features)
103
+
104
+ # normalizing and computing dot similarity of poem and text embeddings
105
+ poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
106
+ image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
107
+ dot_similarity = image_embeddings_n @ poem_embeddings_n.T
108
+
109
+ # returning top n poems based on embedding similarity
110
+ _, indices = torch.topk(dot_similarity.squeeze(0), n)
111
+ return [poems[idx] for idx in indices]
112
+
113
+ if __name__ == "__main__":
114
+ """
115
+ Creates a PoemTextModel based on configs, and outputs some examples of its prediction.
116
+ """
117
+ # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made)
118
+ train_dataset, val_dataset, test_dataset = get_datasets()
119
+
120
+ model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
121
+ model.eval()
122
+ # Inference: Output some example predictions and write them in a file
123
+ print("_"*20)
124
+ print("Output Examples from test set")
125
+ model, poem_embeddings = get_poem_embeddings(test_dataset, model)
126
+ example = {}
127
+ for i, test_data in enumerate(test_dataset[:100]):
128
+ example[i] = {'Text': test_data["text"], 'True Beyt': test_data["beyt"], "Predicted Beyt":predict_poems_from_text(model, poem_embeddings, test_data["text"], [data['beyt'] for data in test_dataset], n=10)}
129
+ for i in range(10):
130
+ print("Text: ", example[i]['Text'])
131
+ print("True Beyt: ", example[i]['True Beyt'])
132
+ print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
133
+ with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
134
+ f.write(json.dumps(example, ensure_ascii=False, indent= 4))
135
+
136
+ print("Preparing model for user input...")
137
+ with open(CFG.dataset_path, encoding="utf-8") as f:
138
+ dataset = json.load(f)
139
+
140
+ model, poem_embeddings = get_poem_embeddings(dataset, model)
141
+
142
+ while(True):
143
+ user_text = input("Enter a Text to find poem beyts for: ")
144
+ beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10)
145
+ print("predicted Beyts: \n\t", "\n\t".join(beyts))
146
+ with open('{}_output__{}_{}.json'.format(user_text, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f:
147
+ f.write(json.dumps(beyts, ensure_ascii=False, indent= 4))
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import get_datasets, build_loaders
2
+ from models import PoemTextModel
3
+ from train import train, test
4
+ from metrics import calc_metrics
5
+ from inference import predict_poems_from_text
6
+ from utils import get_poem_embeddings
7
+ import config as CFG
8
+ import json
9
+
10
+ def main():
11
+ """
12
+ Creates a PoemTextModel based on configs and trains, tests and outputs some examples of its prediction.
13
+ """
14
+ # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made)
15
+ train_dataset, val_dataset, test_dataset = get_datasets()
16
+
17
+ train_loader = build_loaders(train_dataset, mode="train")
18
+ valid_loader = build_loaders(val_dataset, mode="valid")
19
+
20
+ # train a PoemTextModel and write its loss history in a file
21
+ model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
22
+ model, loss_history = train(model, train_loader, valid_loader)
23
+ with open('loss_history_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
24
+ f.write(json.dumps(loss_history, indent= 4))
25
+
26
+ # compute accuracy, mean rank and MRR using test set and write them in a file
27
+ model.eval()
28
+ print("Accuracy on test set: ", test(model, test_dataset))
29
+ metrics = calc_metrics(test_dataset, model)
30
+ print('mean rank: ', metrics["mean_rank"])
31
+ print('mean reciprocal rank (MRR)', metrics["mean_reciprocal_rank_(MRR)"])
32
+ with open('test_metrics_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
33
+ f.write(json.dumps(metrics, indent= 4))
34
+
35
+ # Inference: Output some example predictions and write them in a file
36
+ print("_"*20)
37
+ print("Output Examples from test set")
38
+ model, poem_embeddings = get_poem_embeddings(test_dataset, model)
39
+ example = {}
40
+ for i, test_data in enumerate(test_dataset[:100]):
41
+ example[i] = {'Text': test_data["text"], 'True Beyt': test_data["beyt"], "Predicted Beyt":predict_poems_from_text(model, poem_embeddings, test_data["text"], [data['beyt'] for data in test_dataset], n=10)}
42
+ for i in range(10):
43
+ print("Text: ", example[i]['Text'])
44
+ print("True Beyt: ", example[i]['True Beyt'])
45
+ print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
46
+ with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
47
+ f.write(json.dumps(example, ensure_ascii=False, indent= 4))
48
+
49
+ if __name__ == "__main__":
50
+ main()
main_clip.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import get_datasets, build_loaders
2
+ from models import PoemTextModel
3
+ from train import train, test
4
+ from metrics import calc_metrics
5
+ from inference import predict_poems_from_text
6
+ from utils import get_poem_embeddings
7
+ import config as CFG
8
+ import json
9
+
10
+ def main():
11
+ """
12
+ Creates a PoemTextModel based on configs and trains, tests and outputs some examples of its prediction.
13
+ """
14
+ train_or_not = input("Train a new CLIP model using text embeddings? (needs the sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets to be downloaded)\n[Y/N]")
15
+ if train_or_not == 'Y':
16
+ # Please download sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets from kaggle
17
+ # !kaggle datasets download -d sajjadayobi360/cc3mfav2
18
+ # !kaggle datasets download -d adityajn105/flickr8k
19
+ #.... TODO
20
+ clip_dataset_dict = []
21
+ # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made)
22
+ train_dataset, val_dataset, test_dataset = get_clip_datasets(clip_dataset_dict)
23
+
24
+ train_loader = build_image_loaders(train_dataset, mode="train")
25
+ valid_loader = build_image_loaders(val_dataset, mode="valid")
26
+
27
+ # train a PoemTextModel and write its loss history in a file
28
+ model = CLIPModel(image_encoder_pretrained=True,
29
+ text_encoder_pretrained=True,
30
+ text_projection_trainable=False,
31
+ is_image_poem_pair=False
32
+ ).to(CFG.device)
33
+ model, loss_history = train(model, train_loader, valid_loader)
34
+ with open('loss_history_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
35
+ f.write(json.dumps(loss_history, indent= 4))
36
+
37
+ # Inference: Get a filename and output predictions then write them in a file
38
+ print("_"*20)
39
+ print("INFERENCE PHASE")
40
+ model = CLIPModel(image_encoder_pretrained=True,
41
+ text_encoder_pretrained=True,
42
+ text_projection_trainable=False,
43
+ is_image_poem_pair=True
44
+ ).to(CFG.device)
45
+ model.eval()
46
+ with open(CFG.dataset_path, encoding="utf-8") as f:
47
+ dataset = json.load(f)
48
+
49
+ model, poem_embeddings = get_poem_embeddings(test_dataset, model)
50
+
51
+ while(True):
52
+ image_filename = input("Enter an image filename to predict poems for")
53
+ beyts = predict_poems_from_image(model, poem_embeddings, image_filename, [data['beyt'] for data in dataset], n=10)
54
+ print("predicted Beyts: \n\t", "\n\t".join(beyts))
55
+ with open('{}_output__{}_{}.json'.format(image_filename, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f:
56
+ f.write(json.dumps(beyts, ensure_ascii=False, indent= 4))
57
+
58
+ if __name__ == "__main__":
59
+ main()
metrics.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import inference
4
+ from utils import get_poem_embeddings
5
+ import config as CFG
6
+
7
+ #for running this script as main
8
+ from utils import get_datasets, build_loaders
9
+ from models import PoemTextModel
10
+ from train import train, test
11
+ import json
12
+ import os
13
+
14
+ def calc_metrics(test_dataset, model):
15
+ """
16
+ compute ranks of the test_dataset (and mean rank and MRR)
17
+
18
+ Parameters:
19
+ -----------
20
+ test_dataset: list of dict
21
+ dataset containing text and poem beyts to compute metrics from
22
+ model: PoemTextModel
23
+ The PoemTextModel model to get poem embeddings from and predict poems for each text
24
+ """
25
+ # computing all poems embeddings once (to avoid computing them for each test text)
26
+ m , embedding = get_poem_embeddings(test_dataset, model)
27
+ # adding poems and texts
28
+ poems = []
29
+ meanings = []
30
+ for p in np.array(test_dataset):
31
+ poems.append(p['beyt'])
32
+ meanings.append(p['text'])
33
+ # instantiating a text tokenizer to encode texts
34
+ text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
35
+ rank = []
36
+ for i, meaning in enumerate(meanings):
37
+ # predict most similar poem beyts for each text
38
+ sorted_pred = inference.predict_poems_from_text(model, embedding, meaning, poems, text_tokenizer, n=len(test_dataset))
39
+ # find index of this text's true beyt in the sorted predictions
40
+ idx = sorted_pred.index(poems[i])
41
+ rank.append(idx+1)
42
+ rank = np.array(rank)
43
+ metrics = {
44
+ "mean_rank": np.mean(rank),
45
+ "mean_reciprocal_rank_(MRR)":np.mean(np.reciprocal(rank.astype(float))),
46
+ "rank": rank.tolist()
47
+ }
48
+ return metrics
49
+
50
+ if __name__ == "__main__":
51
+ """
52
+ Creates a PoemTextModel based on configs, and computes its metrics.
53
+ """
54
+ # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made)
55
+ train_dataset, val_dataset, test_dataset = get_datasets()
56
+
57
+ model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
58
+ model.eval()
59
+ # compute accuracy, mean rank and MRR using test set and write them in a file
60
+ print("Accuracy on test set: ", test(model, test_dataset))
61
+ metrics = calc_metrics(test_dataset, model)
62
+ print('mean rank: ', metrics["mean_rank"])
63
+ print('mean reciprocal rank (MRR)', metrics["mean_reciprocal_rank_(MRR)"])
64
+ with open('test_metrics_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
65
+ f.write(json.dumps(metrics, indent= 4))
models.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ #FIX
6
+ import config as CFG
7
+ from modules import TextEncoder, ProjectionHead, ImageEncoder
8
+
9
+
10
+ class PoemTextModel(nn.Module):
11
+ """
12
+ Model predicting poem and text embeddings, and their similarities.
13
+ ...
14
+ Attributes:
15
+ -----------
16
+ poem_encoder : TextEncoder
17
+ encoder used for extracting poem embeddings
18
+ text_encoder : TextEncoder
19
+ encoder used for extracting text embeddings
20
+ poem_projection: ProjectionHead
21
+ projection head used for poem embeddings (projects poem encoder output to shared embedding space)
22
+ text_projection: ProjectionHead
23
+ projection head used for text embeddings (projects text encoder output to shared embedding space)
24
+ temperature: float
25
+ used to scale the dot similarities
26
+
27
+ Methods:
28
+ --------
29
+ forward(batch):
30
+ returns poem and text embeddings of batch
31
+ similarity_scores(batch):
32
+ computes dot similarities of a batch of text-poem pair
33
+ predict(batch):
34
+ predicts the most similar poem idx for each text (using previous methods)
35
+ calculate_loss(batch):
36
+ computes contrastive (cross entropy) loss for both poems and texts.
37
+ save_current():
38
+ saves current model's encoders (if trainable) and projection heads.
39
+ """
40
+ def __init__(
41
+ self,
42
+ poem_encoder_pretrained,
43
+ text_encoder_pretrained,
44
+ temperature=CFG.temperature,
45
+ poem_embedding=CFG.poem_embedding,
46
+ text_embedding=CFG.text_embedding,
47
+ ):
48
+ """
49
+ Initializes model's submodules
50
+ Parameters:
51
+ -----------
52
+ poem_encoder_pretrained: bool
53
+ whether or not to load a pretrained poem encoder.
54
+ text_encoder_pretrained: bool
55
+ whether or not to load a pretrained text encoder.
56
+ temperature: float, optional
57
+ used to scale the dot similarities
58
+ poem_embedding: int, optional
59
+ dim of poem encoder's encoding output before projection
60
+ text_embedding: int, optional
61
+ dim of text encoder's encoding output before projection
62
+ """
63
+ super().__init__()
64
+ self.poem_encoder = TextEncoder(CFG.poem_encoder_model, CFG.poem_encoder_pretrained_name, pretrained=poem_encoder_pretrained, trainable= CFG.poem_encoder_trainable)
65
+ self.text_encoder = TextEncoder(CFG.text_encoder_model, CFG.text_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable= CFG.text_encoder_trainable)
66
+
67
+ self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
68
+ if CFG.poem_projection_load_path: # if provided, load projection weights from this path
69
+ self.poem_projection.load_state_dict(torch.load(CFG.poem_projection_load_path, map_location=CFG.device))
70
+
71
+ self.text_projection = ProjectionHead(embedding_dim=text_embedding)
72
+ if CFG.text_projection_load_path: # if provided, load projection weights from this path
73
+ self.text_projection.load_state_dict(torch.load(CFG.text_projection_load_path, map_location=CFG.device))
74
+
75
+ self.temperature = temperature
76
+
77
+ def forward(self, batch):
78
+ """
79
+ returns poem and text embeddings of batch
80
+
81
+ Parameters:
82
+ -----------
83
+ batch: list of dict
84
+ input (containing poem-text pairs (encoded using the encoder's tokenizer) with keys 'beyt' and 'text')
85
+
86
+ Returns:
87
+ --------
88
+ poem and text embeddings of batch (each of shape (batch_size, projection_dim))
89
+ """
90
+ beyts, texts = batch["beyt"], batch["text"]
91
+ # Getting Beyt and Text Features
92
+ poem_features = self.poem_encoder(
93
+ input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"]
94
+ )
95
+ text_features = self.text_encoder(
96
+ input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
97
+ )
98
+ # Getting Beyt and Text Embeddings (with same dimension)
99
+ poem_embeddings = self.poem_projection(poem_features)
100
+ text_embeddings = self.text_projection(text_features)
101
+
102
+ return poem_embeddings, text_embeddings
103
+
104
+ def similarity_scores(self, batch):
105
+ """
106
+ computes dot similarities of a batch of text-poem pair
107
+
108
+ Parameters:
109
+ -----------
110
+ batch: list of dict
111
+ input (containing poem-text pairs (encoded using the encoder's tokenizer) with keys 'beyt' and 'text')
112
+
113
+ Returns:
114
+ --------
115
+ dot similarity of poem and text embeddings of batch (of shape (batch_size, batch_size))
116
+ """
117
+ # Getting Beyt and Text Embeddings (with same dimension)
118
+ poem_embeddings, text_embeddings = self.forward(batch)
119
+ # Normalizing embeddings
120
+ poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
121
+ text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
122
+ # Computing dot / cosine similarity of the normalized embeddings
123
+ dot_similarity = text_embeddings_n @ poem_embeddings_n.T
124
+ return dot_similarity # (batch_size, batch_size) first dim is texts, second dim is poems for each text
125
+
126
+ def predict(self, batch):
127
+ """
128
+ predicts the most similar poem (idx) for each text (using previous methods)
129
+
130
+ Parameters:
131
+ -----------
132
+ batch: list of dict
133
+ input (containing poem-text pairs (encoded using the encoder's tokenizer) with keys 'beyt' and 'text')
134
+
135
+ Returns:
136
+ --------
137
+ index of poem predicted for each text (of shape (batch_size))
138
+ """
139
+ dot_similarity = self.similarity_scores(batch)
140
+ # Getting argmax in first dimension of the dot-similarities to predict index of the most similar poem for each text
141
+ return torch.argmax(dot_similarity, dim=1)
142
+
143
+ def calculate_loss(self, poem_embeddings, text_embeddings):
144
+ """
145
+ computes contrastive (cross entropy) loss for both poems and texts.
146
+
147
+ Parameters:
148
+ -----------
149
+ poem_embeddings: of shape (batch_size, projection_dim)
150
+ output embeddings of poem projection head
151
+ text_embeddings: of shape (batch_size, projection_dim)
152
+ output embeddings of text projection head
153
+
154
+ Returns:
155
+ --------
156
+ average of the loss computed from inputs
157
+ """
158
+ # dot similarity of the embeddings scaled by temperature (logits)
159
+ logits = (text_embeddings @ poem_embeddings.T) / self.temperature
160
+ # computing targets for the cross entropy loss to compare with logits.
161
+ # each embedding's similarity is computed with itself and then added,
162
+ # scaled by the temperature parameter, and normalized into a probability distribution via a softmax
163
+ poems_similarity = poem_embeddings @ poem_embeddings.T
164
+ texts_similarity = text_embeddings @ text_embeddings.T
165
+ targets = F.softmax(
166
+ (poems_similarity + texts_similarity) / 2 * self.temperature, dim=-1
167
+ )
168
+ # taking cross entropy loss in both dimensions: once for texts and once for poems
169
+ texts_loss = cross_entropy(logits, targets, reduction='none')
170
+ poems_loss = cross_entropy(logits.T, targets.T, reduction='none')
171
+ loss = (poems_loss + texts_loss) / 2.0 # average of losses. shape: (batch_size)
172
+ return loss.mean()
173
+
174
+ def save_current(self):
175
+ """
176
+ saves current model's encoders (if trainable) and projection heads.
177
+ """
178
+ if CFG.text_encoder_trainable:
179
+ self.text_encoder.model.save_pretrained(CFG.text_encoder_save_path)
180
+ if CFG.poem_encoder_trainable:
181
+ self.poem_encoder.model.save_pretrained(CFG.poem_encoder_save_path)
182
+ torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
183
+ torch.save(self.poem_projection.state_dict(), CFG.poem_projection_save_path)
184
+
185
+ class CLIPModel(nn.Module):
186
+ """
187
+ Model predicting poem/text and image embeddings, and their similarities.
188
+ ...
189
+ Attributes:
190
+ -----------
191
+ encoder : TextEncoder
192
+ encoder used for extracting poem/text embeddings
193
+ image_encoder : ImageEncoder
194
+ encoder used for extracting image embeddings
195
+ text_projection: ProjectionHead
196
+ projection head used for poem/text embeddings (projects text encoder output to shared embedding space)
197
+ image_projection: ProjectionHead
198
+ projection head used for image embeddings (projects image encoder output to shared embedding space)
199
+ temperature: float
200
+ used to scale the dot similarities
201
+
202
+ Methods:
203
+ --------
204
+ forward(batch):
205
+ returns poem/text and image embeddings of batch
206
+ similarity_scores(batch):
207
+ computes dot similarities of a batch of text-image pair
208
+ predict(batch):
209
+ predicts the most similar poem/text idx for each image (using previous methods)
210
+ calculate_loss(batch):
211
+ computes contrastive (cross entropy) loss for both poems/texts and images.
212
+ save_current():
213
+ saves current model's encoders (if trainable) and projection heads.
214
+ """
215
+ def __init__(
216
+ self,
217
+ image_encoder_pretrained,
218
+ text_encoder_pretrained,
219
+ text_projection_trainable,
220
+ temperature=CFG.temperature,
221
+ image_embedding=CFG.image_embedding,
222
+ text_embedding=CFG.text_embedding,
223
+ is_image_poem_pair=True
224
+ ):
225
+ """
226
+ Initializes model's submodules
227
+ Parameters:
228
+ -----------
229
+ image_encoder_pretrained: bool
230
+ whether or not to load a pretrained image encoder.
231
+ text_encoder_pretrained: bool
232
+ whether or not to load a pretrained text encoder.
233
+ text_projection_trainable: bool
234
+ whether or not to train text projection
235
+ (since the text projection is frozen in our trainings unlike other projections of models)
236
+ temperature: float, optional
237
+ used to scale the dot similarities
238
+ image_embedding: int, optional
239
+ dim of image encoder's encoding output before projection
240
+ text_embedding: int, optional
241
+ dim of text encoder's encoding output before projection
242
+ is_image_poem_pair: bool, optional
243
+ if True, the text inputs to this model is poems and needs one of the poem encoders to predict embeddings with.
244
+ else it's a text that needs the encoders dedicated to text.
245
+ """
246
+ super().__init__()
247
+ # Loading the encoders and their projections using configs
248
+ self.image_encoder = ImageEncoder(pretrained=image_encoder_pretrained, trainable=CFG.image_encoder_trainable)
249
+
250
+ if is_image_poem_pair:
251
+ self.encoder = TextEncoder(CFG.poem_encoder_model, CFG.poem_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=CFG.poem_encoder_trainable)
252
+ self.text_projection = ProjectionHead(embedding_dim=text_embedding)
253
+ if CFG.poem_projection_load_path:
254
+ self.text_projection.load_state_dict(torch.load(CFG.poem_projection_load_path, map_location=CFG.device))
255
+ else:
256
+ self.encoder = TextEncoder(CFG.text_encoder_model, CFG.text_encoder_pretrained_name, pretrained=text_encoder_pretrained, trainable=CFG.text_encoder_trainable)
257
+ self.text_projection = ProjectionHead(embedding_dim=text_embedding)
258
+ if CFG.text_projection_load_path:
259
+ self.text_projection.load_state_dict(torch.load(CFG.text_projection_load_path, map_location=CFG.device))
260
+
261
+ self.image_projection = ProjectionHead(embedding_dim=image_embedding)
262
+ if CFG.image_projection_load_path:
263
+ self.image_projection.load_state_dict(torch.load(CFG.image_projection_load_path, map_location=CFG.device))
264
+
265
+ if not text_projection_trainable:
266
+ for p in self.text_projection.parameters():
267
+ p.requires_grad = False
268
+
269
+ self.text_projection_trainable = text_projection_trainable
270
+ self.is_image_poem_pair = is_image_poem_pair
271
+ self.temperature = temperature
272
+
273
+ def forward(self, batch):
274
+ """
275
+ returns image and text/poem embeddings of batch
276
+
277
+ Parameters:
278
+ -----------
279
+ batch: list of dict
280
+ input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer)
281
+ with keys 'image' and 'text')
282
+
283
+ Returns:
284
+ --------
285
+ poem/text and image embeddings of batch (each of shape (batch_size, projection_dim))
286
+ """
287
+ image, texts = batch["image"], batch["text"]
288
+ # Getting Image and Text Features
289
+ image_features = self.image_encoder(batch["image"])
290
+ text_features = self.encoder(
291
+ input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
292
+ )
293
+ # Getting Image and Text Embeddings (with same dimension)
294
+ image_embeddings = self.image_projection(image_features)
295
+ text_embeddings = self.text_projection(text_features)
296
+
297
+ return image_embeddings, text_embeddings
298
+
299
+ def similarity_scores(self, batch):
300
+ """
301
+ computes dot similarities of a batch of text/poem-image pair
302
+
303
+ Parameters:
304
+ -----------
305
+ batch: list of dict
306
+ input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer)
307
+ with keys 'image' and 'text')
308
+
309
+ Returns:
310
+ --------
311
+ dot similarity of poem/text and image embeddings of batch (of shape (batch_size, batch_size))
312
+ """
313
+ # Getting Image and Text Embeddings (with same dimension)
314
+ image_embeddings, text_embeddings = self.forward(batch)
315
+ # Normalizing embeddings
316
+ image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
317
+ text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
318
+ # Computing dot / cosine similarity of the normalized embeddings
319
+ dot_similarity = image_embeddings_n @ text_embeddings_n.T
320
+ return dot_similarity # (batch_size, batch_size) first dim is images, second dim is poems/texts for each image
321
+
322
+ def predict(self, batch):
323
+ """
324
+ predicts the most similar poem/text (idx) for each image (using previous methods)
325
+
326
+ Parameters:
327
+ -----------
328
+ batch: list of dict
329
+ input (containing image-text/poem pairs (text/poem encoded using the encoder's tokenizer)
330
+ with keys 'image' and 'text')
331
+
332
+ Returns:
333
+ --------
334
+ index of poem/text predicted for each image (of shape (batch_size))
335
+ """
336
+ dot_similarity = self.similarity_scores(batch)
337
+ # Getting argmax in first dimension of the dot-similarities
338
+ # to predict index of the most similar poem/text for each image
339
+ return torch.argmax(dot_similarity, dim=1)
340
+
341
+ def calculate_loss(self, image_embeddings, text_embeddings):
342
+ """
343
+ computes contrastive (cross entropy) loss for both poems/texts and images.
344
+
345
+ Parameters:
346
+ -----------
347
+ image_embeddings: of shape (batch_size, projection_dim)
348
+ output embeddings of image projection head
349
+ text_embeddings: of shape (batch_size, projection_dim)
350
+ output embeddings of text projection head
351
+
352
+ Returns:
353
+ --------
354
+ average of the loss computed from inputs
355
+ """
356
+ # dot similarity of the embeddings scaled by temperature (logits)
357
+ logits = (text_embeddings @ image_embeddings.T) / self.temperature
358
+ # computing targets for the cross entropy loss to compare with logits.
359
+ # each embedding's similarity is computed with itself and then averaged,
360
+ # scaled by the temperature parameter, and normalized into a probability distribution via a softmax
361
+ images_similarity = image_embeddings @ image_embeddings.T
362
+ texts_similarity = text_embeddings @ text_embeddings.T
363
+ targets = F.softmax(
364
+ (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
365
+ )
366
+ # taking cross entropy loss in both dimensions: once for texts and once for images
367
+ texts_loss = cross_entropy(logits, targets, reduction='none')
368
+ images_loss = cross_entropy(logits.T, targets.T, reduction='none')
369
+ loss = (images_loss + texts_loss) / 2.0 # average of losses. shape: (batch_size)
370
+ return loss.mean()
371
+
372
+ def save_current(self):
373
+ """
374
+ saves current model's encoders and projection heads (if trainable).
375
+ """
376
+ if self.is_image_poem_pair:
377
+ if CFG.poem_encoder_trainable:
378
+ self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
379
+ else:
380
+ if CFG.text_encoder_trainable:
381
+ self.encoder.model.save_pretrained(CFG.text_encoder_save_path)
382
+ if CFG.image_encoder_trainable:
383
+ torch.save(self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path)
384
+ if self.text_projection_trainable:
385
+ torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
386
+ torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)
387
+
388
+ def cross_entropy(preds, targets, reduction='none'):
389
+ """
390
+ Computes cross_entropy of logits and targets using their last dimension
391
+
392
+ Parameters:
393
+ -----------
394
+ preds: tensor/numpy array
395
+ logits
396
+ targets: tensor/ numpy array
397
+ reduction: str, optional
398
+ if set to "mean", return loss mean across all dimensions.
399
+ if set to "none", return loss computed using last dim.
400
+
401
+ Returns:
402
+ --------
403
+ loss or loss average
404
+ """
405
+ log_softmax = nn.LogSoftmax(dim=-1)
406
+ loss = (-targets * log_softmax(preds)).sum(1) # cross entropy loss
407
+ if reduction == "none":
408
+ return loss
409
+ elif reduction == "mean":
410
+ return loss.mean()
modules.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import timm
4
+ import config as CFG
5
+
6
+
7
+ class TextEncoder(nn.Module):
8
+ """
9
+ Text/Poem encoder used in PoemTextModel and CLIPModel
10
+ ...
11
+ Attributes:
12
+ -----------
13
+ model : a torch.nn.Module model
14
+ The image encoder model
15
+
16
+ Methods:
17
+ --------
18
+ forward(x)
19
+ returns model embeddings of x (batch of texts/poems) (of the CLS token)
20
+ __init__()
21
+ creates the encoder model using huggingface transformers,
22
+ also freezes the model if it's not trainable.
23
+ """
24
+ def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
25
+ """
26
+ creates the poem or text encoder model using transformers and loads weights from pretrained model if needed.
27
+ Also freezes the model if it's not trainable.
28
+
29
+ Parameters:
30
+ -----------
31
+ pretrained: bool
32
+ if pretrained=True, get pretrained model's weights. else create a fresh untrained model.
33
+ trainable: bool
34
+ if trainable=False, the model's weights will be frozen.
35
+ encoder_model: str
36
+ image encoder model name used as input to get the right model from configs.
37
+ encoder_pretrained_name: str
38
+ image encoder model to get weights from. (not used when pretrained=False)
39
+ """
40
+ super().__init__()
41
+
42
+ if pretrained:
43
+ self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
44
+ else:
45
+ self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())
46
+
47
+ for p in self.model.parameters():
48
+ p.requires_grad = trainable
49
+
50
+ # Using the CLS token hidden representation as the sentence's embedding
51
+ self.target_token_idx = 0
52
+
53
+ def forward(self, input_ids, attention_mask):
54
+ """
55
+ forwards and calculates embeddings of the input using attention mask.
56
+
57
+ Parameters:
58
+ -----------
59
+ input_ids: input ids (output of tokenizer)
60
+ attention masks: input masks (for example for padding, pad tokens will be masked)
61
+
62
+ Returns:
63
+ --------
64
+ the embedding of the CLS (or target) token of the encoder's last hidden state
65
+ """
66
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask)
67
+ last_hidden_state = output.last_hidden_state
68
+ return last_hidden_state[:, self.target_token_idx, :]
69
+
70
+
71
+
72
+ class ProjectionHead(nn.Module):
73
+ """
74
+ Projection head used to project embeddings from each encoder to a shared embedding space
75
+ ...
76
+ Attributes:
77
+ -----------
78
+ projection : torch.nn.Linear
79
+ The main Dense projection (from encoder's embedding dim to shared embedding projection dim)
80
+ gelu: torch.nn.GELU
81
+ activation function
82
+ fc: torch.nn.Linear
83
+ a dense layer after projection (projection_dim to projection_dim)
84
+ dropout: torch.nn.Dropout
85
+ dropout after fc
86
+ layer_norm: torch.nn.LayerNorm
87
+ layer norm after dropout
88
+
89
+ Methods:
90
+ --------
91
+ forward(x)
92
+ returns projection embeddings from x (encoder output embeddings)
93
+ __init__()
94
+ creates the projection head
95
+ """
96
+ def __init__(
97
+ self,
98
+ embedding_dim,
99
+ projection_dim=CFG.projection_dim,
100
+ dropout=CFG.dropout
101
+ ):
102
+ """
103
+ Creates the projection head used after an encoder.
104
+
105
+ Parameters:
106
+ -----------
107
+ embedding_dim: int
108
+ dimension of the output embeddings of the encoder.
109
+ projection_dim: int, optional
110
+ dimension to project embeddings to.
111
+ dropout: float
112
+ fraction of the output of fc layer to be zeroed.
113
+ """
114
+ super().__init__()
115
+ self.projection = nn.Linear(embedding_dim, projection_dim)
116
+ self.gelu = nn.GELU()
117
+ self.fc = nn.Linear(projection_dim, projection_dim)
118
+ self.dropout = nn.Dropout(dropout)
119
+ self.layer_norm = nn.LayerNorm(projection_dim)
120
+
121
+ def forward(self, x):
122
+ """
123
+ Forwards and calculates projected embeddings from encoder embeddings.
124
+
125
+ Parameters:
126
+ -----------
127
+ x: input (of shape (batch_size, embedding_dim))
128
+ the output embedding of this projection head's encoder
129
+
130
+ Returns:
131
+ --------
132
+ the embeddings in a shared embedding space (of shape (batch_size, projection_dim))
133
+ """
134
+ projected = self.projection(x) #main projection layer
135
+ x = self.gelu(projected)
136
+ x = self.fc(x)
137
+ x = self.dropout(x)
138
+ # the projected outputs are added to x as a residual connection
139
+ x = x + projected
140
+ x = self.layer_norm(x)
141
+ return x
142
+
143
+
144
+ class ImageEncoder(nn.Module):
145
+ """
146
+ Image encoder used in CLIPModel
147
+ ...
148
+ Attributes:
149
+ -----------
150
+ model : a torch.nn.Module model from timm (pytorch-image-models)
151
+ The image encoder model
152
+
153
+ Methods:
154
+ --------
155
+ forward(x)
156
+ returns model embeddings of x (batch of images)
157
+ __init__()
158
+ creates the encoder model using timm and loads fine-tuned model's state dict if needed.
159
+ also freezes the model if it's not trainable.
160
+ """
161
+ def __init__(
162
+ self, pretrained, trainable, model_name=CFG.image_encoder_model
163
+ ):
164
+ """
165
+ creates the encoder model using timm and loads fine-tuned model's state dict if needed.
166
+ Also freezes the model if it's not trainable.
167
+
168
+ Parameters:
169
+ -----------
170
+ pretrained: bool
171
+ if pretrained=True, get SOTA weights (or weights saved in image_encoder_weights_load_path).
172
+ else create a fresh untrained model.
173
+ trainable: bool
174
+ if trainable=False, the model's weights will be frozen.
175
+ model_name: str
176
+ image encoder model name used as input to timm.create_model.
177
+ """
178
+ super().__init__()
179
+ self.model = timm.create_model(
180
+ model_name, pretrained, num_classes=0, global_pool="avg"
181
+ )
182
+ if pretrained and CFG.image_encoder_weights_load_path:
183
+ self.model.load_state_dict(torch.load(CFG.image_encoder_weights_load_path, map_location=CFG.device))
184
+ for p in self.model.parameters():
185
+ p.requires_grad = trainable
186
+
187
+ def forward(self, x):
188
+ """
189
+ forwards and calculates embeddings of the input.
190
+
191
+ Parameters:
192
+ -----------
193
+ x: input (batch of transformed images)
194
+
195
+ Returns:
196
+ --------
197
+ embeddings of the model for the input (of shape (batch_size, image_embedding))
198
+ """
199
+ return self.model(x)
projections/LaBSE_best_text_projection.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42369217ef5104e0ccf452ad310b2d2dcfc81d20d6444532d70c44bb064e76d8
3
+ size 7358959
projections/ParsBERT_best_poem_projection.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:953022eab4908ab16e512446c11e7edf32a2ec8e7379de0d6748d52e7dda9773
3
+ size 7358983
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ Pillow
4
+ scikit_learn
5
+ torch
6
+ torchvision
7
+ tqdm
8
+ transformers
9
+ timm
10
+ opencv-python
11
+ albumentations
12
+ gradio
train.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ import random
7
+ import json
8
+
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ #FIX
14
+ import config as CFG
15
+ from models import CLIPModel
16
+ from utils import AvgMeter, get_lr
17
+ from utils import get_datasets, build_loaders
18
+
19
+ def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
20
+ """
21
+ Performs one epoch of training.
22
+
23
+ Parameters:
24
+ -----------
25
+ model: PoemTextModel or CLIPModel
26
+ model to train
27
+ train_loader: torch.utils.data.DataLoader
28
+ dataloader to get batches from
29
+ optimizer: torch.optim.Optimizer
30
+ optimizer used for training
31
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler
32
+ scheduler used for training
33
+ step: str ("batch" or "epoch")
34
+ if "batch", lr_scheduler will step (update) for each batch of loader.
35
+ else lr_scheduler only steps and updates after finishing each epoch.
36
+
37
+ Returns:
38
+ --------
39
+ loss_meter: AvgMeter
40
+ the class containing average loss of this epoch's training
41
+ """
42
+ loss_meter = AvgMeter() # to track average of loss
43
+ tqdm_object = tqdm(train_loader, total=len(train_loader))
44
+ for batch_cpu in tqdm_object:
45
+ # put batch data on device
46
+ batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()} for k, v in batch_cpu.items() if not k in ["id", "image"]}
47
+ if "image" in batch_cpu:
48
+ batch["image"] = batch_cpu["image"].to(CFG.device)
49
+
50
+ #get model's embeddings and calculate loss
51
+ poem_or_img_embeddings, text_embeddings = model(batch)
52
+ loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
53
+
54
+ # backpropagate and step
55
+ optimizer.zero_grad()
56
+ loss.backward()
57
+ optimizer.step()
58
+ if step == "batch":
59
+ lr_scheduler.step()
60
+
61
+ #update training info
62
+ count = batch["text"]["input_ids"].size(0)
63
+ loss_meter.update(loss.item(), count)
64
+ tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
65
+ # print('train loss: ', loss_meter.avg)
66
+ return loss_meter
67
+
68
+
69
+ def valid_epoch(model, valid_loader):
70
+ """
71
+ Performs one epoch of validation.
72
+
73
+ Parameters:
74
+ -----------
75
+ model: PoemTextModel or CLIPModel
76
+ model to validate
77
+ valid_loader: torch.utils.data.DataLoader
78
+ dataloader to get batches from.
79
+
80
+ Returns:
81
+ --------
82
+ loss_meter: AvgMeter
83
+ the class containing average loss of this epoch's validation
84
+ """
85
+ loss_meter = AvgMeter() # to track average of loss
86
+ tqdm_object = tqdm(valid_loader, total=len(valid_loader))
87
+ for batch_cpu in tqdm_object:
88
+ # put batch data on device
89
+ batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()} for k, v in batch_cpu.items() if not k in ["id", "image"]}
90
+ if "image" in batch_cpu:
91
+ batch["image"] = batch_cpu["image"].to(CFG.device)
92
+
93
+ #get model's embeddings and calculate loss
94
+ poem_or_img_embeddings, text_embeddings = model(batch)
95
+ loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
96
+
97
+ #update validation info
98
+ count = batch["text"]["input_ids"].size(0)
99
+ loss_meter.update(loss.item(), count)
100
+ tqdm_object.set_postfix(valid_loss=loss_meter.avg)
101
+ # print('validation loss: ', loss_meter.avg)
102
+ return loss_meter
103
+
104
+ def test(model, test_dataset):
105
+ """
106
+ Calculates accuracy on test set.
107
+ This method is used for the PoemTextModel, since the other model (CLIPModel) does not have a test set containing pairs of image-poem.
108
+
109
+ Parameters:
110
+ -----------
111
+ model: PoemTextModel
112
+ model to test
113
+ test_dataset: list of dict
114
+ the list containing dict of data to perform test on (must have "text" and "poem" keys)
115
+
116
+ Returns:
117
+ --------
118
+ accuracy: np.float
119
+ The accuracy of model on the test set given
120
+ """
121
+ test_loader = build_loaders(test_dataset, mode="test")
122
+ accuracy = 0
123
+ tqdm_object = tqdm(test_loader, total=len(test_loader))
124
+ model.eval()
125
+ with torch.no_grad():
126
+ for batch_cpu in tqdm_object:
127
+ # put batch data on device
128
+ batch = {k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()} for k, v in batch_cpu.items() if not k in ["id", "image"]}
129
+ if "image" in batch_cpu:
130
+ batch["image"] = batch_cpu["image"].to(CFG.device)
131
+
132
+ # get model's prediction for each text (a numpy array of index/labels showing which poem belongs to which text)
133
+ pred = model.predict(batch).cpu().numpy()
134
+
135
+ count = batch["text"]["input_ids"].size(0)
136
+ # since each text is associated with the poem with the same index as it, np.arange(count) is the real labels.
137
+ acc = np.sum(pred == np.arange(count))
138
+ accuracy += acc
139
+
140
+ tqdm_object.set_postfix(accuracy=acc / count)
141
+ accuracy /= len(test_dataset)
142
+ return accuracy
143
+
144
+ def train(model, train_loader, valid_loader, epochs=CFG.epochs):
145
+ """
146
+ Performs train and validation for (epochs) epochs.
147
+
148
+ Parameters:
149
+ -----------
150
+ model: PoemTextModel or CLIPModel
151
+ model to train
152
+ train_loader: torch.utils.data.DataLoader
153
+ train dataloader to get batches from
154
+ valid_loader: torch.utils.data.DataLoader
155
+ validation dataloader to get batches from
156
+ epochs: int, optional
157
+ the number of epochs to train
158
+
159
+ Returns:
160
+ --------
161
+ model: PoemTextModel or CLIPModel
162
+ trained model
163
+ loss_history: dict
164
+ a dict containing train and validation average loss for each epoch.
165
+ """
166
+ # Using AdamW optimizer and ReduceLROnPlateau lr-scheduler with settings from config
167
+ optimizer = torch.optim.AdamW(
168
+ model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
169
+ )
170
+ lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
171
+ optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
172
+ )
173
+
174
+ # if step="batch", lr_scheduler will step (update) for each batch of loader.
175
+ # else lr_scheduler only steps and updates after finishing each epoch. (this case)
176
+ step = "epoch"
177
+ loss_history = {"train":[], "valid":[]}
178
+
179
+ # to keep track of best validation loss
180
+ best_loss = float('inf')
181
+ for epoch in range(CFG.epochs):
182
+ print(f"Epoch: {epoch + 1}")
183
+ # train for one epoch
184
+ model.train()
185
+ train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
186
+ loss_history["train"].append(train_loss.avg)
187
+
188
+ # validate trained model
189
+ model.eval()
190
+ with torch.no_grad():
191
+ valid_loss = valid_epoch(model, valid_loader)
192
+ loss_history["valid"].append(valid_loss.avg)
193
+
194
+ # if this epoch's avg validation loss is lower than best loss, save and keep this model.
195
+ if valid_loss.avg < best_loss:
196
+ best_loss = valid_loss.avg
197
+ model.save_current()
198
+ print("Saved Best Model!")
199
+
200
+ if step == "epoch":
201
+ lr_scheduler.step(valid_loss.avg)
202
+ return model, loss_history
utils.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import config as CFG
2
+ import json
3
+ from models import PoemTextModel
4
+ import torch
5
+ import random
6
+ from datasets import PoemTextDataset, get_transforms, CLIPDataset
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+ class AvgMeter:
11
+ """
12
+ Used to keep track of batch losses during training / validation.
13
+ ...
14
+ Attributes:
15
+ -----------
16
+ name : str
17
+ count : int
18
+ number of data whose train/val loss has been metered
19
+ sum: int or float
20
+ sum of all losses metered
21
+ avg: int or float
22
+ average of metered losses
23
+
24
+ Methods:
25
+ --------
26
+ reset():
27
+ Sets count, sum and avg to 0.
28
+ update(val, count=1):
29
+ Updates loss sum, count and avg.
30
+ __repr__():
31
+ string representation of this class.
32
+ """
33
+ def __init__(self, name="Metric"):
34
+ """Sets the name of the avg meter. sets avg, sum & count to 0."""
35
+ self.name = name
36
+ self.reset()
37
+
38
+ def reset(self):
39
+ """Sets avg, sum & count to 0."""
40
+ self.avg, self.sum, self.count = [0] * 3
41
+
42
+ def update(self, val, count=1):
43
+ """Updates loss sum, count and avg using val and count (count of the val input)"""
44
+ self.count += count
45
+ self.sum += val * count
46
+ self.avg = self.sum / self.count
47
+
48
+ def __repr__(self):
49
+ """String representation of this class"""
50
+ text = f"{self.name}: {self.avg:.4f}"
51
+ return text
52
+
53
+ def get_lr(optimizer):
54
+ """Returns learning rate of the input optimizer"""
55
+ for param_group in optimizer.param_groups:
56
+ return param_group["lr"]
57
+
58
+ def get_datasets():
59
+ """
60
+ Returns train, validation & test split from a dataset json file specified using CFG.dataset_path.
61
+ This function first loads the file into a list of dict and shuffles them with CFG.random_seed seed,
62
+ then splits them using CFG.train_propotion & CFG.val_propotion.
63
+
64
+ Returns:
65
+ --------
66
+ train_dataset: list of dict
67
+ Train split
68
+ val_dataset: list of dict
69
+ Validation split
70
+ test_dataset: list of dict
71
+ Test split
72
+ """
73
+ with open(CFG.dataset_path, encoding="utf-8") as f:
74
+ dataset = json.load(f)
75
+ random.Random(CFG.random_seed).shuffle(dataset)
76
+ # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
77
+ train_dataset, val_dataset, test_dataset = np.split(dataset,
78
+ [int(CFG.train_propotion*len(dataset)), int((CFG.train_propotion + CFG.val_propotion)*len(dataset))])
79
+ return train_dataset, val_dataset, test_dataset
80
+
81
+
82
+ def build_loaders(dataset_dict, mode):
83
+ """
84
+ Returns a torch Dataloader from a list of dictionaries (dataset_dict).
85
+ First makes a PoemTextDataset which is a torch Dataset object from dataset_dict and then instantiates a Dataloader.
86
+
87
+ Parameters:
88
+ -----------
89
+ dataset_dict: list of dict
90
+ the dataset to return a dataloader of.
91
+ mode: str ("train" or any other word)
92
+ if the mode is "train", dataloader will activate shuffling.
93
+
94
+ Returns:
95
+ --------
96
+ dataloader: torch.utils.data.DataLoader
97
+ the torch Dataloader created from dataset_dict using PoemTextDataset and configs.
98
+ """
99
+ dataset = PoemTextDataset(
100
+ dataset_dict
101
+ )
102
+ dataloader = torch.utils.data.DataLoader(
103
+ dataset,
104
+ batch_size=CFG.batch_size,
105
+ num_workers=CFG.num_workers,
106
+ shuffle=True if mode == "train" else False,
107
+ )
108
+ return dataloader
109
+
110
+ def get_clip_datasets(dataset_dict):
111
+ """
112
+ (Used for clip model training) Returns train, validation & test split from input.
113
+ This function takes a list of dict as dataset and shuffles them with CFG.random_seed seed,
114
+ then splits them using CFG.train_propotion & CFG.val_propotion.
115
+
116
+ Parameters:
117
+ -----------
118
+ dataset_dict: list of dict
119
+ the input dataset
120
+ Returns:
121
+ --------
122
+ train_dataset: list of dict
123
+ Train split
124
+ val_dataset: list of dict
125
+ Validation split
126
+ test_dataset: list of dict
127
+ Test split
128
+ """
129
+ random.Random(CFG.random_seed).shuffle(dataset_dict)
130
+ # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
131
+ train_dataset, val_dataset, test_dataset = np.split(dataset_dict,
132
+ [int(CFG.train_propotion*len(dataset_dict)), int((CFG.train_propotion + CFG.val_propotion)*len(dataset_dict))])
133
+ return train_dataset, val_dataset, test_dataset
134
+
135
+
136
+ def build_image_loaders(dataset_dict, mode):
137
+ """
138
+ (Used for clip model training) Returns a torch Dataloader from a list of dictionaries (dataset_dict).
139
+ First makes a PoemTextDataset which is a torch Dataset object from dataset_dict and then instantiates a Dataloader.
140
+
141
+ Parameters:
142
+ -----------
143
+ dataset_dict: list of dict
144
+ the dataset to return a dataloader of.
145
+ mode: str ("train" or any other word)
146
+ if the mode is "train", dataloader will activate shuffling.
147
+
148
+ Returns:
149
+ --------
150
+ dataloader: torch.utils.data.DataLoader
151
+ the torch Dataloader created from dataset_dict using CLIPDataset and configs.
152
+ """
153
+ transforms = get_transforms(mode=mode)
154
+ dataset = CLIPDataset(
155
+ dataset_dict, transforms, is_image_poem_pair=False
156
+ )
157
+ dataloader = torch.utils.data.DataLoader(
158
+ dataset,
159
+ batch_size=CFG.batch_size,
160
+ num_workers=CFG.num_workers,
161
+ shuffle=True if mode == "train" else False,
162
+ )
163
+ return dataloader
164
+
165
+ def get_poem_embeddings(test_dataset, model=None):
166
+ """
167
+ Returns embeddings of the poems existing in test_dataset.
168
+
169
+ Parameters:
170
+ -----------
171
+ test_dataset: list of dict
172
+ dataset to get poems from. each of its dictionaries must have a "beyt" key.
173
+ model: PoemTextModel, optional
174
+ The PoemTextModel model to get poem embeddings from.
175
+ If None is given, instantiates a new model (with all of its parts in pretrained settings) using configurations provided in config.py.
176
+
177
+ Returns:
178
+ --------
179
+ model (PoemTextModel): The model used for creating poem embeddings
180
+ """
181
+ test_loader = build_loaders(test_dataset, mode="test") # building a dataloder (which also tokenizes the poems)
182
+
183
+ if model == None:
184
+ model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)
185
+ model.eval()
186
+
187
+ poem_embeddings = []
188
+ with torch.no_grad():
189
+ for batch in tqdm(test_loader):
190
+ # get poem embeddings by passing tokenizer output of the poems
191
+ # to the model's poem encoder and projection
192
+ beyts = {
193
+ key: values.to(CFG.device)
194
+ for key, values in batch["beyt"].items()
195
+ }
196
+ if model.__class__.__name__ == "PoemTextModel":
197
+ poem_features = model.poem_encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
198
+ poem_emb = model.poem_projection(poem_features)
199
+ poem_embeddings.append(poem_emb)
200
+ elif model.__class__.__name__ == "CLIPModel":
201
+ poem_features = model.encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
202
+ poem_emb = model.text_projection(poem_features)
203
+ poem_embeddings.append(poem_emb)
204
+ else:
205
+ raise #not a right model to use!
206
+
207
+ return model, torch.cat(poem_embeddings)