from __future__ import annotations

import json
import os

import numpy as np

import config as CFG
import inference
from utils import get_poem_embeddings

# for running this script as main
from utils import get_datasets, build_loaders
from models import PoemTextModel
from train import train, test


def calc_metrics(test_dataset, model):
    """
    Compute the rank of each test sample's true beyt (as well as mean rank and MRR).

    Parameters
    ----------
    test_dataset: list of dict
        dataset containing texts and poem beyts to compute metrics from
    model: PoemTextModel
        the PoemTextModel to get poem embeddings from and predict poems for each text

    Returns
    -------
    metrics: dict
        mean rank, mean reciprocal rank (MRR), and the per-sample rank list
    """
    # compute all poem embeddings once (to avoid recomputing them for every test text)
    m, embeddings = get_poem_embeddings(test_dataset, model)

    # collect poem beyts and their corresponding texts
    poems = []
    meanings = []
    for p in test_dataset:
        poems.append(p['beyt'])
        meanings.append(p['text'])

    # instantiate a text tokenizer to encode the texts
    text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)

    ranks = []
    for i, meaning in enumerate(meanings):
        # predict the most similar poem beyts for this text, sorted by similarity
        sorted_pred = inference.predict_poems_from_text(
            model, embeddings, meaning, poems, text_tokenizer, n=len(test_dataset)
        )
        # the 1-based position of this text's true beyt in the sorted predictions is its rank
        idx = sorted_pred.index(poems[i])
        ranks.append(idx + 1)

    ranks = np.array(ranks)
    metrics = {
        # cast to built-in float so the metrics stay JSON-serializable
        "mean_rank": float(np.mean(ranks)),
        "mean_reciprocal_rank_(MRR)": float(np.mean(np.reciprocal(ranks.astype(float)))),
        "rank": ranks.tolist(),
    }
    return metrics


if __name__ == "__main__":
    """
    Creates a PoemTextModel based on configs and computes its metrics.
    """
    # get datasets from dataset_path (the same train, val and test splits as the dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()

    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()

    # compute accuracy, mean rank and MRR on the test set and write them to a file
    print("Accuracy on test set: ", test(model, test_dataset))
    metrics = calc_metrics(test_dataset, model)
    print('mean rank: ', metrics["mean_rank"])
    print('mean reciprocal rank (MRR): ', metrics["mean_reciprocal_rank_(MRR)"])

    with open('test_metrics_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),
              'w', encoding="utf-8") as f:
        f.write(json.dumps(metrics, indent=4))
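

# A minimal sketch of the metric arithmetic used in calc_metrics, for reference only
# (the ranks below are hypothetical, not taken from any real run): mean rank averages
# the 1-based ranks of each true beyt among the sorted predictions, and MRR averages
# their reciprocals. This helper is never called by the evaluation pipeline.
def _example_metric_arithmetic():
    example_ranks = np.array([1, 3, 2])
    mean_rank = float(np.mean(example_ranks))                           # (1 + 3 + 2) / 3 = 2.0
    mrr = float(np.mean(np.reciprocal(example_ranks.astype(float))))    # (1/1 + 1/3 + 1/2) / 3 ≈ 0.611
    return mean_rank, mrr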