File size: 2,560 Bytes
2fa2727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from utils import get_datasets, build_loaders
from models import PoemTextModel
from train import train, test
from metrics import calc_metrics
from inference import predict_poems_from_text
from utils import get_poem_embeddings
import config as CFG
import json

def main():
    """
    Train a PoemTextModel, evaluate it on the test set and save example predictions.

    Steps, in order:
      1. Load the train/validation/test datasets and build the train and
         validation loaders.
      2. Build a PoemTextModel (both encoders pretrained), move it to
         ``CFG.device``, train it, and write the returned loss history to
         ``loss_history_<poem_encoder_model>_<text_encoder_model>.json``.
      3. Print the test accuracy, mean rank and MRR, and write the metrics to
         ``test_metrics_<poem_encoder_model>_<text_encoder_model>.json``.
      4. Run ``predict_poems_from_text`` (with ``n=10``) on the first 100 test
         texts, print up to the first 10 results, and write all of them to
         ``example_output__<poem_encoder_model>_<text_encoder_model>.json``.

    All output files use relative paths, so they land in the current working
    directory.
    """
    # get dataset from dataset_path (the same splits as the train, val and test
    # dataset files in the data directory)
    train_dataset, val_dataset, test_dataset = get_datasets()

    train_loader = build_loaders(train_dataset, mode="train")
    valid_loader = build_loaders(val_dataset, mode="valid")

    # train a PoemTextModel and write its loss history in a file
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model, loss_history = train(model, train_loader, valid_loader)
    with open('loss_history_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
        f.write(json.dumps(loss_history, indent= 4))

    # compute accuracy, mean rank and MRR using test set and write them in a file
    model.eval()
    print("Accuracy on test set: ", test(model, test_dataset))
    metrics = calc_metrics(test_dataset, model)
    print('mean rank: ', metrics["mean_rank"])
    print('mean reciprocal rank (MRR)', metrics["mean_reciprocal_rank_(MRR)"])
    with open('test_metrics_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
        f.write(json.dumps(metrics, indent= 4))

    # Inference: Output some example predictions and write them in a file
    print("_"*20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)

    # The candidate beyts are the same for every query, so build the list once
    # instead of once per example.
    # NOTE(review): assumes predict_poems_from_text does not mutate this list — confirm.
    candidate_beyts = [data['beyt'] for data in test_dataset]

    example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            "Predicted Beyt": predict_poems_from_text(model, poem_embeddings, test_data["text"], candidate_beyts, n=10),
        }

    # Print at most 10 examples; a test set with fewer than 10 entries must not
    # raise KeyError and prevent the example file from being written.
    for i in range(min(10, len(example))):
        print("Text: ", example[i]['Text'])
        print("True Beyt: ", example[i]['True Beyt'])
        print("predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))

    # ensure_ascii=False keeps non-ASCII text readable in the output file
    # instead of \uXXXX escapes.
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent= 4))

# Run the full train / test / inference pipeline only when this file is
# executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()