from __future__ import annotations
import torch
import cv2
import torch.nn.functional as F
import config as CFG
from datasets import get_transforms

# imports used when running this script as __main__
from utils import get_datasets, get_poem_embeddings
from models import PoemTextModel
import json


def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10):
    """
    Returns n poems which are the most similar to a text query

        Parameters:
        -----------
            model: PoemTextModel
                model to compute text query's embeddings
            poem_embeddings: torch.Tensor of shape (#poems, CFG.projection_dim)
                poem embeddings to check similarity
            query: str
                text query
            poems: list of str
                poems corresponding to poem_embeddings
            text_tokenizer: huggingface Tokenizer, optional
                tokenizer to tokenize the query with. If None, a new text tokenizer is instantiated using the configs.
            n: int, optional
                number of poems to return

        Returns:
        --------
            A list of n poem strings whose embeddings are most similar to the query text's embedding.

    """
    # Tokenize and encode the query text
    if text_tokenizer is None:
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)

    encoded_query = text_tokenizer([query])
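    # encoded_query is a dict of lists (keys such as input_ids and attention_mask,
    # depending on the tokenizer), with one entry per input string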
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }

    # getting query text's embeddings
    model.eval()
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids= batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    
    # normalizing and computing dot similarity of poem and text embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

    # with L2-normalized embeddings, this dot product equals cosine similarity
    dot_similarity = text_embeddings_n @ poem_embeddings_n.T

    # returning top n poems based on embedding similarity
    _, indices = torch.topk(dot_similarity.squeeze(0), n)
    return [poems[idx] for idx in indices]
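
# A minimal usage sketch for predict_poems_from_text (the query string here is
# hypothetical; `model` and `poem_embeddings` are assumed to come from
# get_poem_embeddings, as in the __main__ block below):
#
#     poems = [data["beyt"] for data in dataset]
#     top_beyts = predict_poems_from_text(model, poem_embeddings,
#                                         "a moonlit night of longing", poems, n=5)
#     print("\n".join(top_beyts))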


def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10):
    """
    Returns n poems which are the most similar to an image query

        Parameters:
        -----------
            model: CLIPModel
                model to compute image query's embeddings
            poem_embeddings: torch.Tensor of shape (#poems, CFG.projection_dim)
                poem embeddings to check similarity
            image_filename: str
                path and file name for the image query
            poems: list of str
                poems corresponding to poem_embeddings
            n: int, optional
                number of poems to return

        Returns:
        --------
            A list of n poem strings whose embeddings are most similar to the image query's embedding.

    """
    # Read the image, convert BGR (OpenCV's default) to RGB, and apply the
    # test-time transforms (all explained in datasets.py)
    image = cv2.imread(image_filename)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    # HWC -> CHW, the channel layout the image encoder expects
    image = torch.tensor(image).permute(2, 0, 1).float()

    # getting image query's embeddings
    model.eval()
    with torch.no_grad():
        image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
        image_embeddings = model.image_projection(image_features)
    
    # normalizing and computing dot similarity of poem and image embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ poem_embeddings_n.T

    # returning top n poems based on embedding similarity
    _, indices = torch.topk(dot_similarity.squeeze(0), n)
    return [poems[idx] for idx in indices]
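
# A minimal usage sketch for predict_poems_from_image (the image path is
# hypothetical; the model must expose image_encoder and image_projection,
# e.g. a trained CLIPModel):
#
#     top_beyts = predict_poems_from_image(clip_model, poem_embeddings,
#                                          "samples/garden.jpg", poems, n=5)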

if __name__ == "__main__":
    # Create a PoemTextModel based on the configs and print some example predictions.
    # Load the datasets from dataset_path (the same train, val and test splits as the dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()

    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()
    # Inference: Output some example predictions and write them in a file
    print("_"*20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
    example = {}
    test_poems = [data['beyt'] for data in test_dataset]
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {'Text': test_data["text"], 'True Beyt': test_data["beyt"],
                      "Predicted Beyt": predict_poems_from_text(model, poem_embeddings, test_data["text"], test_poems, n=10)}
    for i in range(10):
        print("Text: ", example[i]['Text'])
        print("True Beyt: ", example[i]['True Beyt'])
        print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent=4))
    
    print("Preparing model for user input...")
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    model, poem_embeddings = get_poem_embeddings(dataset, model)
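    # poem_embeddings is expected to be a tensor of shape
    # (len(dataset), CFG.projection_dim), matching the docstrings above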

    poems = [data['beyt'] for data in dataset]
    while True:
        user_text = input("Enter a text to find poem beyts for: ")
        beyts = predict_poems_from_text(model, poem_embeddings, user_text, poems, n=10)
        print("Predicted Beyts: \n\t", "\n\t".join(beyts))
        # sanitize the query before using it as a file name (raw user text may
        # contain characters that are invalid in paths)
        slug = "".join(ch if ch.isalnum() else "_" for ch in user_text)[:50]
        with open('{}_output__{}_{}.json'.format(slug, CFG.poem_encoder_model, CFG.text_encoder_model), 'a+', encoding="utf-8") as f:
            f.write(json.dumps(beyts, ensure_ascii=False, indent=4))