from __future__ import annotations import torch import cv2 import torch.nn.functional as F import numpy as np import config as CFG from datasets import get_transforms #for running this script as main from utils import get_datasets, build_loaders from models import PoemTextModel from utils import get_poem_embeddings import json import os import regex def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False): """ Returns n poems which are the most similar to a text query Parameters: ----------- model: PoemTextModel model to compute text query's embeddings poem_embeddings: sequence with shape (#poems, CFG.projection_dim) poem embeddings to check similarity query: str text query poems: list of str poems corresponding to poem_embeddings text_tokenizer: huggingface Tokenizer, optional tokenizer to tokenize query with. if none, will instantiate a new text tokenizer using configs. n: int, optional number of poems to return return_similarities: bool, optional if True, a dictionary will be returned which has the poem beyts and their similarities to the text Returns: -------- A list of n poem strings whose embeddings are the most similar to query text's embedding. """ #Tokenizing and Encoding the query text if not text_tokenizer: text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer) encoded_query = text_tokenizer([query]) batch = { key: torch.tensor(values).to(CFG.device) for key, values in encoded_query.items() } # getting query text's embeddings model.eval() with torch.no_grad(): text_features = model.text_encoder( input_ids= batch["input_ids"], attention_mask=batch["attention_mask"] ) text_embeddings = model.text_projection(text_features) # normalizing and computing dot similarity of poem and text embeddings poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1) text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1) dot_similarity = text_embeddings_n @ poem_embeddings_n.T # returning top n poems based on embedding similarity values, indices = torch.topk(dot_similarity.squeeze(0), len(poems)) # since we collected poems from many sources, some of them are equal (the same beyt with different meanings), # so we must check the poems added to result not to be duplicates def is_poem_duplicate(poem, poems): poem = regex.findall(r'\p{L}+', poem.replace('\u200c', '')) for other_poem in poems: other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', '')) if poem == other_poem: return True return False results = [] computed_k = 0 for i in range(len(poems)): if computed_k == n: break if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]): results.append({ 'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''), 'similarity': values[i] }) computed_k += 1 if return_similarities: return results else: return [res['beyt'] for res in results] def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False): """ Returns n poems which are the most similar to an image query Parameters: ----------- model: CLIPModel model to compute image query's embeddings poem_embeddings: sequence with shape (#poems, CFG.projection_dim) poem embeddings to check similarity image_filename: str path and file name for the image query poems: list of str poems corresponding to poem_embeddings n: int, optional number of poems to return return_similarities: bool, optional if True, a dictionary will be returned which has the poem beyts and their similarities to the text Returns: -------- A list of n poem strings whose embeddings are the most similar to image query's embedding. """ # Reading, Processing and applying transforms to image (all explained in datasets.py) image = cv2.imread(f"{image_filename}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = get_transforms(mode="test")(image=image)['image'] image = torch.tensor(image).permute(2, 0, 1).float() # getting image query's embeddings model.eval() with torch.no_grad(): image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device)) image_embeddings = model.image_projection(image_features) # normalizing and computing dot similarity of poem and text embeddings poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1) image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1) dot_similarity = image_embeddings_n @ poem_embeddings_n.T # returning top n poems based on embedding similarity values, indices = torch.topk(dot_similarity.squeeze(0), len(poems)) # since we collected poems from many sources, some of them are equal (the same beyt with different meanings), # so we must check the poems added to result not to be duplicates def is_poem_duplicate(poem, poems): poem = regex.findall(r'\p{L}+', poem.replace('\u200c', '')) for other_poem in poems: other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', '')) if poem == other_poem: return True return False results = [] computed_k = 0 for i in range(len(poems)): if computed_k == n: break if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]): results.append({ 'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''), 'similarity': values[i] }) computed_k += 1 if return_similarities: return results else: return [res['beyt'] for res in results] if __name__ == "__main__": """ Creates a PoemTextModel based on configs, and outputs some examples of its prediction. """ # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made) train_dataset, val_dataset, test_dataset = get_datasets() model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device) model.eval() # Inference: Output some example predictions and write them in a file print("_"*20) print("Output Examples from test set") model, poem_embeddings = get_poem_embeddings(test_dataset, model) example = {} for i, test_data in enumerate(test_dataset[:100]): example[i] = {'Text': test_data["text"], 'True Beyt': test_data["beyt"], "Predicted Beyt":predict_poems_from_text(model, poem_embeddings, test_data["text"], [data['beyt'] for data in test_dataset], n=10)} for i in range(10): print("Text: ", example[i]['Text']) print("True Beyt: ", example[i]['True Beyt']) print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"])) with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f: f.write(json.dumps(example, ensure_ascii=False, indent= 4)) print("Preparing model for user input...") with open(CFG.dataset_path, encoding="utf-8") as f: dataset = json.load(f) model, poem_embeddings = get_poem_embeddings(dataset, model) while(True): user_text = input("Enter a Text to find poem beyts for: ") beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10) print("predicted Beyts: \n\t", "\n\t".join(beyts)) with open('{}_output__{}_{}.json'.format(user_text, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f: f.write(json.dumps(beyts, ensure_ascii=False, indent= 4))