File size: 2,862 Bytes
1bc9b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from utils import get_datasets, build_loaders
from models import PoemTextModel
from train import train, test
from metrics import calc_metrics
from inference import predict_poems_from_text
from utils import get_poem_embeddings
import config as CFG
import json

def main():
    """
    Optionally train a CLIP-style model, then interactively predict beyts for images.

    Flow:
      1. Ask whether to train. On "Y" (case-insensitive, surrounding whitespace
         ignored) run ``_train_clip_text_model``.
      2. Run ``_run_inference``: embed every poem in ``CFG.dataset_path`` and, for
         each image filename the user enters, print and save the top-10 beyts.

    NOTE(review): ``CLIPModel``, ``get_clip_datasets``, ``build_image_loaders`` and
    ``predict_poems_from_image`` are not imported in this file as shown; they must
    be brought into scope (presumably from models / utils / inference) for this
    function to run.
    """
    train_or_not = input("Train a new CLIP model using text embeddings? (needs the sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets to be downloaded)\n[Y/N]")
    # Accept "y", " Y " etc. as well as the exact "Y" the prompt asks for.
    if train_or_not.strip().upper() == 'Y':
        _train_clip_text_model()

    _run_inference()


def _train_clip_text_model():
    """
    Train a CLIPModel on text pairs and dump its loss history to a JSON file.

    The loss history is written to
    ``loss_history_<CFG.poem_encoder_model>_<CFG.text_encoder_model>.json``.
    Dataset loading is still a TODO: ``clip_dataset_dict`` is empty as written.
    """
    # Please download sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets from kaggle
    #   !kaggle datasets download -d sajjadayobi360/cc3mfav2
    #   !kaggle datasets download -d adityajn105/flickr8k
    # TODO: load the downloaded datasets into clip_dataset_dict.
    clip_dataset_dict = []
    # The test split is not used by training; it is kept only to match the
    # helper's return shape.
    train_dataset, val_dataset, _test_dataset = get_clip_datasets(clip_dataset_dict)

    train_loader = build_image_loaders(train_dataset, mode="train")
    valid_loader = build_image_loaders(val_dataset, mode="valid")

    model = CLIPModel(
        image_encoder_pretrained=True,
        text_encoder_pretrained=True,
        text_projection_trainable=False,
        is_image_poem_pair=False,
    ).to(CFG.device)
    # NOTE(review): the trained model is not saved or reused here; inference
    # builds a fresh model. Confirm whether train() persists the weights itself.
    model, loss_history = train(model, train_loader, valid_loader)

    history_path = f"loss_history_{CFG.poem_encoder_model}_{CFG.text_encoder_model}.json"
    with open(history_path, 'w', encoding="utf-8") as f:
        json.dump(loss_history, f, indent=4)


def _run_inference():
    """
    Embed all poems in ``CFG.dataset_path`` and answer image-filename queries.

    Each prediction (a list of 10 beyts) is printed and appended as JSON to
    ``<image_filename>_output__<CFG.poem_encoder_model>_<CFG.text_encoder_model>.json``.
    An empty filename or end of input stops the loop.
    """
    print("_" * 20)
    print("INFERENCE PHASE")
    model = CLIPModel(
        image_encoder_pretrained=True,
        text_encoder_pretrained=True,
        text_projection_trainable=False,
        is_image_poem_pair=True,
    ).to(CFG.device)
    model.eval()
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Embed the same dataset the candidate beyts are drawn from. The old code
    # used `test_dataset`, which is undefined unless training ran. Presumably
    # predict_poems_from_image maps embedding rows back to `candidate_beyts` by
    # index, so the two must come from the same list.
    model, poem_embeddings = get_poem_embeddings(dataset, model)
    candidate_beyts = [data['beyt'] for data in dataset]  # built once, reused per query

    while True:
        try:
            image_filename = input("Enter an image filename to predict poems for").strip()
        except EOFError:
            break
        if not image_filename:
            break
        beyts = predict_poems_from_image(model, poem_embeddings, image_filename, candidate_beyts, n=10)
        print("predicted Beyts: \n\t", "\n\t".join(beyts))
        output_path = f"{image_filename}_output__{CFG.poem_encoder_model}_{CFG.text_encoder_model}.json"
        # NOTE(review): appending means repeated queries for one image produce
        # concatenated JSON arrays, not a single valid JSON document.
        with open(output_path, 'a', encoding="utf-8") as f:
            f.write(json.dumps(beyts, ensure_ascii=False, indent=4))


if __name__ == "__main__":
    main()