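"""
Inference script: retrieves the poem beyts most similar to a text or image query
by comparing projected embeddings from a trained PoemTextModel (or CLIPModel for images).
"""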
from __future__ import annotations
import torch
import cv2
import torch.nn.functional as F
import config as CFG
from datasets import get_transforms

# imports used when running this script as __main__
from utils import get_datasets, get_poem_embeddings
from models import PoemTextModel
import json
import regex

def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
    """
    Returns n poems which are the most similar to a text query

        Parameters:
        -----------
            model: PoemTextModel
                model to compute text query's embeddings
            poem_embeddings: sequence with shape (#poems, CFG.projection_dim)
                poem embeddings to check similarity
            query: str
                text query
            poems: list of str
                poems corresponding to poem_embeddings
            text_tokenizer: huggingface Tokenizer, optional
                tokenizer to tokenize query with. if none, will instantiate a new text tokenizer using configs.
            n: int, optional
                number of poems to return
            return_similarities: bool, optional
                if True, a dictionary will be returned which has the poem beyts and their similarities to the text

        Returns:
        --------
            A list of n poem strings whose embeddings are the most similar to query text's embedding.  
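
        Example:
        --------
            Illustrative sketch (the query string is a placeholder; `model`,
            `poem_embeddings`, and `dataset` follow the setup in the __main__ block below):

            >>> model, poem_embeddings = get_poem_embeddings(dataset, model)
            >>> beyts = predict_poems_from_text(
            ...     model, poem_embeddings, "a rainy night of longing",
            ...     [data['beyt'] for data in dataset], n=5)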

    """
    # tokenizing and encoding the query text
    if not text_tokenizer:
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)

    encoded_query = text_tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }

    # getting query text's embeddings
    model.eval()
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids= batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    
    # normalizing and computing dot similarity of poem and text embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

    dot_similarity = text_embeddings_n @ poem_embeddings_n.T
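    # both sides are L2-normalized, so this dot product equals their cosine similarity (range [-1, 1])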

    # rank all poems by similarity; the full ordering is needed since duplicates may be skipped below
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))

    # since we collected poems from many sources, some entries are effectively equal
    # (the same beyt with different meanings), so we make sure no duplicates end up in the results
    def is_poem_duplicate(poem, poems):
        # compare word sequences, ignoring punctuation and zero-width non-joiners (U+200C)
        poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem == other_poem:
                return True
        return False
    
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                'similarity': values[i].item()  # plain float, so the result is JSON-serializable
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]


def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
    """
    Returns n poems which are the most similar to an image query

        Parameters:
        -----------
            model: CLIPModel
                model to compute image query's embeddings
            poem_embeddings: sequence with shape (#poems, CFG.projection_dim)
                poem embeddings to check similarity
            image_filename: str
                path and file name for the image query
            poems: list of str
                poems corresponding to poem_embeddings
            n: int, optional
                number of poems to return
            return_similarities: bool, optional
                if True, a dictionary will be returned which has the poem beyts and their similarities to the text

        Returns:
        --------
            A list of n poem strings whose embeddings are the most similar to image query's embedding.  
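
        Example:
        --------
            Illustrative sketch ("image.jpg" is a placeholder path; `model` here is a
            trained CLIPModel, and `poem_embeddings`/`dataset` follow the __main__ setup):

            >>> beyts = predict_poems_from_image(
            ...     model, poem_embeddings, "image.jpg",
            ...     [data['beyt'] for data in dataset], n=5)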

    """
    # read the image, convert BGR -> RGB, and apply the test-time transforms (see datasets.py)
    image = cv2.imread(image_filename)
    if image is None:
        # cv2.imread returns None silently on a missing/unreadable file; fail loudly instead
        raise FileNotFoundError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    image = torch.tensor(image).permute(2, 0, 1).float()

    # getting image query's embeddings
    model.eval()
    with torch.no_grad():
        image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
        image_embeddings = model.image_projection(image_features)
    
    # normalizing and computing dot similarity of poem and image embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ poem_embeddings_n.T
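    # as in predict_poems_from_text, the normalized dot product is the cosine similarity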

    # rank all poems by similarity; the full ordering is needed since duplicates may be skipped below
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    
    # since we collected poems from many sources, some entries are effectively equal
    # (the same beyt with different meanings), so we make sure no duplicates end up in the results
    def is_poem_duplicate(poem, poems):
        # compare word sequences, ignoring punctuation and zero-width non-joiners (U+200C)
        poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem == other_poem:
                return True
        return False
    
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                'similarity': values[i].item()  # plain float, so the result is JSON-serializable
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]


if __name__ == "__main__":
    # Build a PoemTextModel from the configs and print some example predictions.

    # load the train, val and test splits from dataset_path
    # (the same splits as the dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()

    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()
    # Inference: Output some example predictions and write them in a file
    print("_"*20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
    example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            'Predicted Beyt': predict_poems_from_text(
                model, poem_embeddings, test_data["text"],
                [data['beyt'] for data in test_dataset], n=10
            ),
        }
    for i in range(10):
        print("Text: ", example[i]['Text'])
        print("True Beyt: ", example[i]['True Beyt'])
        print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent=4))
    
    print("Preparing model for user input...")
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    model, poem_embeddings = get_poem_embeddings(dataset, model)

    while True:
        user_text = input("Enter a text to find poem beyts for: ")
        beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10)
        print("Predicted Beyts: \n\t", "\n\t".join(beyts))
        # keep only word characters, whitespace and dashes so the user text can safely be used as a file name
        safe_name = regex.sub(r'[^\w\s-]', '', user_text).strip()
        with open('{}_output__{}_{}.json'.format(safe_name, CFG.poem_encoder_model, CFG.text_encoder_model), 'a+', encoding="utf-8") as f:
            f.write(json.dumps(beyts, ensure_ascii=False, indent=4))