File size: 3,142 Bytes
bbc5e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
"""
@author: Van Duc <vvduc03@gmail.com>
"""
"""Import necessary packages"""
import os
import torch
import random
import matplotlib.pyplot as plt

def read_caption(num_caption, vocab):
    """
    Convert a caption from numeric (token-id) form to a list of words.

    Args:
        num_caption: 1-D tensor of token ids; index 0 is assumed to be the
            leading <SOS> token and is skipped.
        vocab: vocabulary object exposing ``itos`` (token id -> word string).

    Returns:
        A list of word strings up to, and excluding, the first "<EOS>"
        token (ex: ['a', 'dog', 'in', 'the', 'sky']).
    """
    words = []
    # Skip the leading <SOS> token; stop at the first <EOS>.
    # Single lookup per token (the original resolved every id twice:
    # once in the loop test and again in a final comprehension).
    for token_id in num_caption[1:]:
        word = vocab.itos[token_id.item()]
        if word == "<EOS>":
            break
        words.append(word)

    return words

def plot_examples(model, device, dataset, vocab, num_examples=20):
    """
    Plot images with their ground-truth and predicted captions.

    Args:
        model: pretrained model exposing ``caption_image(image, vocab)``
        device: target device (cpu or gpu)
        dataset: dataset indexable as (image, caption) pairs
        vocab: vocabulary with an ``itos`` mapping
        num_examples: number of random examples to display

    Returns:
        None; prints both captions and shows one matplotlib figure
        per example.
    """
    model.eval()
    model.to(device)

    # Loop over examples
    for example in range(num_examples):
        # random.randrange(len(...)) is half-open, so the index is always
        # valid. The original random.randint(0, len(dataset)) included
        # len(dataset) itself and could raise an IndexError.
        image, caption = dataset[random.randrange(len(dataset))]
        image = image.to(device)

        # Print ground truth vs. model prediction
        correct = f"Example {example+1} CORRECT: " + " ".join(read_caption(caption, vocab))
        output = f"Example {example+1} OUTPUT: " + " ".join(model.caption_image(image.unsqueeze(0), vocab))
        print(correct)
        print(output)
        print('----------------------------------------------')

        # Plot image and caption.
        # NOTE(review): dataset.img is assumed to hold the raw image of the
        # most recent __getitem__ call — confirm against the dataset class.
        fig, ax = plt.subplots()
        ax.imshow(dataset.img)
        ax.axis('off')
        fig.text(0.5, 0.05,
                 correct + '\n' + output,
                 ha="center")

        plt.show()

    # Restore training mode for the caller
    model.train()


def save_checkpoint(model, optimizer, epoch, save_path, last_loss, best_loss):
    """
    Persist the current training state to disk.

    Always writes ``last.pt`` in ``save_path``; additionally writes
    ``best.pt`` when ``last_loss`` improves on ``best_loss``.

    Args:
        model: model whose state_dict is saved
        optimizer: optimizer whose state_dict is saved
        epoch: current epoch index (stored as epoch + 1 so training
            resumes from the following epoch)
        save_path: directory that receives the checkpoint files
        last_loss: loss of the epoch that just finished
        best_loss: best loss observed so far

    Returns:
        The (possibly updated) best loss.
    """
    print("=> Saving checkpoint")
    state = {
        "epoch": epoch + 1,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }

    # The rolling checkpoint is written unconditionally.
    torch.save(state, os.path.join(save_path, "last.pt"))

    # Keep a separate copy whenever the loss improves.
    if last_loss < best_loss:
        best_loss = last_loss
        torch.save(state, os.path.join(save_path, "best.pt"))

    return best_loss

def load_check_point_to_use(checkpoint_file, model, device):
    """
    Load model weights from a checkpoint file for inference only.

    Args:
        checkpoint_file: path to a checkpoint produced by save_checkpoint
        model: model instance that receives the weights
        device: device the tensors are mapped to while loading

    Returns:
        The model with the checkpoint weights loaded in place.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(state["model"])
    return model

def load_checkpoint_to_continue(checkpoint_file, model, optimizer, lr, device):
    """
    Restore model and optimizer state to resume an interrupted training run.

    Args:
        checkpoint_file: directory containing 'last.pt' (as written by
            save_checkpoint)
        model: model instance that receives the saved weights
        optimizer: optimizer instance that receives the saved state
        lr: learning rate forced onto every param group after loading
        device: device the tensors are mapped to while loading

    Returns:
        Tuple of (model, epoch) where epoch is the epoch to resume from.
    """
    print("=> Loading checkpoint")
    # os.path.join is portable across platforms, unlike manual '/'
    # concatenation (checkpoint_file + '/last.pt').
    checkpoint = torch.load(os.path.join(checkpoint_file, 'last.pt'), map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    epoch = checkpoint["epoch"]

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    return model, epoch