Spaces commit: Upload 15 files

- app.py +86 -0
- dataset.py +191 -0
- examples/000000000139.jpg +0 -0
- examples/000000000785.jpg +0 -0
- examples/000000005477.jpg +0 -0
- examples/good1.png +0 -0
- examples/good2.png +0 -0
- examples/good3.png +0 -0
- examples/good6.png +0 -0
- model.py +86 -0
- requirement.txt +4 -0
- savedir/best.pt +3 -0
- savedir/last.pt +3 -0
- utils.py +105 -0
- vocab.json +0 -0
app.py
ADDED
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""
@author: Van Duc <vvduc03@gmail.com>
"""
"""Import necessary packages"""
import os
import argparse
import config
import gradio as gr

from model import ImgCaption_Model
from dataset import Vocabulary
from timeit import default_timer as timer
from utils import load_check_point_to_use

# Initialize and parse command-line arguments
def get_args():
    parse = argparse.ArgumentParser()
    parse.add_argument('--save-path', '-s', type=str, default=config.save_path, help='directory of saved checkpoints')
    parse.add_argument('--transform', default=config.transform, help='Compose transform applied to input images')
    parse.add_argument('--embed-size', default=config.embed_size, help='size of the embedding vectors')
    parse.add_argument('--hidden-size', default=config.hidden_size, help='number of hidden units in the RNN')
    parse.add_argument('--num-layer', default=config.num_layer, help='number of stacked LSTM layers')
    parse.add_argument('--num-workers', default=config.num_workers, help='number of CPU cores used to load data')
    args = parse.parse_args()
    return args

# Load vocab file
vocab = Vocabulary()
vocab.read_vocab()

# Load arguments
args = get_args()

# Load model
model = ImgCaption_Model(args.embed_size, args.hidden_size, len(vocab), args.num_layer)

# Load saved weights
load_check_point_to_use(args.save_path + '/best.pt', model, 'cpu')

def caption(img):
    """Transform the image, generate a caption, and return the caption and the time taken."""
    # Start the timer
    start_time = timer()

    # Transform the target image
    img = args.transform(img)

    # Put the model into evaluation mode and caption the image
    model.eval()
    prompt = " ".join(model.caption_image(img.unsqueeze(0), vocab))

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the caption and prediction time
    return prompt, pred_time


# Create title, description and article strings
def main():
    title = "Image Captioning 🖼➡️🆎"
    description = "A model that describes the content of a picture"
    article = "Created on [GITHUB](https://github.com/vvduc1803/Image-Captioning)."

    # Create the examples list from the "examples/" directory
    example_list = [["examples/" + example] for example in os.listdir("examples")]

    # Create the Gradio demo
    demo = gr.Interface(fn=caption,  # mapping function from input to output
                        inputs=gr.Image(type="pil"),
                        outputs=[gr.Textbox(label="Caption"),  # fn returns two values, so two output components
                                 gr.Number(label="Prediction time (s)")],
                        examples=example_list,
                        title=title,
                        description=description,
                        article=article)

    # Launch the demo!
    demo.launch(server_name="127.0.0.1", server_port=1234, share=True)

if __name__ == '__main__':
    main()
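
For a quick check without launching the UI, the caption function can be called directly. A minimal sketch, assuming the files from this commit (savedir/best.pt, vocab.json, examples/) are on disk:

from PIL import Image
from app import caption  # importing app builds the model and loads savedir/best.pt

img = Image.open("examples/000000000139.jpg").convert("RGB")
text, seconds = caption(img)  # returns (caption string, prediction time in seconds)
print(text, seconds)
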
dataset.py
ADDED
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
@author: Van Duc <vvduc03@gmail.com>
"""
"""Import necessary packages"""
import os
import spacy  # for the tokenizer
import torch
import config
import json

from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

# Download with: python -m spacy download en_core_web_sm
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_threshold=5):
        # Initialize the two dictionaries: index-to-string and string-to-index
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

        # Minimum frequency for a word to be added to the dictionary
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def read_vocab(self, file_name='vocab.json'):
        """
        Load a saved vocabulary file and replace the index-to-string and string-to-index dictionaries.
        """
        with open(file_name, 'r') as f:
            vocab = json.load(f)
        # JSON keys are strings, so convert the itos keys back to integers
        self.itos = {int(key): value for key, value in vocab['itos'].items()}
        self.stoi = vocab['stoi']

    def create_vocab(self, file_name='vocab.json'):
        # Serialize both dictionaries as a JSON object
        vocab = json.dumps({'itos': self.itos,
                            'stoi': self.stoi})

        # Write the JSON object to file
        with open(file_name, "w") as f:
            f.write(vocab)

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

class CoCoDataset(Dataset):
    def __init__(self, root_dir, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.freq_threshold = freq_threshold
        with open(os.path.join(self.root_dir, config.captions), 'r') as f:
            captions_file = json.load(f)
        self.transform = transform

        # Get img, caption columns
        self.imageID_list = [captions['image_id'] for captions in captions_file['annotations']]
        self.captions_list = [captions['caption'] for captions in captions_file['annotations']]

        # # Initialize vocabulary and build vocab
        # if not self.set_vocab:
        #     self.vocab = Vocabulary(self.freq_threshold)
        #     self.vocab.build_vocabulary(self.captions_list)
        #     self.vocab.create_vocab()
        # else:
        #     self.vocab = self.set_vocab

        # Load vocab file
        self.vocab = Vocabulary(self.freq_threshold)
        self.vocab.read_vocab()

    def __len__(self):
        return len(self.imageID_list)

    def __getitem__(self, index):
        # Load the caption and image at this index
        caption = self.captions_list[index]
        img_id = str(self.imageID_list[index]).zfill(12) + '.jpg'
        # Kept on self so plot_examples() can display the untransformed image
        self.img = Image.open(os.path.join(self.root_dir, config.images, img_id)).convert("RGB")

        # Transform the image
        img = self.transform(self.img) if self.transform else self.img

        # Numericalize the caption: <SOS> tokens <EOS>
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)

class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        # Pad captions to the longest one in the batch; shape (seq_len, batch)
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return imgs, targets


def get_loader(
        root_folder,
        transform,
        batch_size=16,
        num_workers=4,
        shuffle=True,
        pin_memory=True
):
    dataset = CoCoDataset(root_folder, transform=transform)

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )
    return dataset, loader


if __name__ == "__main__":
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    train_dataset, train_loader = get_loader(root_folder=config.train,
                                             transform=config.transform,
                                             batch_size=config.batch_size,
                                             num_workers=config.num_workers,
                                             shuffle=True)
    from utils import plot_examples
    from model import ImgCaption_Model
    model = ImgCaption_Model(256, 256, len(train_dataset.vocab), 1)
    plot_examples(model, 'cuda', train_dataset, train_dataset.vocab)
    # imgs, captions = train_dataset[1]
    # print(imgs.shape)
    # print(captions)
    # print(captions.shape)
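
As a reference for the batch shapes get_loader produces, a minimal sketch, assuming config.train points at a local COCO-style folder (the dataset itself is not part of this commit):

import config
import torchvision.transforms as transforms
from dataset import get_loader

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset, loader = get_loader(root_folder=config.train, transform=transform,
                             batch_size=4, num_workers=0)
imgs, targets = next(iter(loader))
print(imgs.shape)     # torch.Size([4, 3, 224, 224])
print(targets.shape)  # (longest_caption_length, 4), since pad_sequence uses batch_first=False
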
examples/000000000139.jpg
ADDED
examples/000000000785.jpg
ADDED
examples/000000005477.jpg
ADDED
examples/good1.png
ADDED
examples/good2.png
ADDED
examples/good3.png
ADDED
examples/good6.png
ADDED
model.py
ADDED
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""
@author: Van Duc <vvduc03@gmail.com>
"""
"""Import necessary packages"""
import torch
import torch.nn as nn
import torchvision.models as models

# from torchinfo import summary  # optional, for model summaries (not listed in requirement.txt)

class CNN(nn.Module):
    def __init__(self, embed_size=256, train_model=False):
        super().__init__()

        # Load the pretrained EfficientNet-B2 model; .DEFAULT selects the weights
        # (passing the enum class itself raises a TypeError in recent torchvision)
        self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)

        # Freeze all layers of the model
        if not train_model:
            for param in self.model.parameters():
                param.requires_grad = False

        # Replace the head of the model; the new classifier's parameters are trainable
        self.model.classifier = nn.Sequential(nn.Linear(1408, embed_size),
                                              nn.ReLU(),
                                              nn.Dropout(0.5))

    def forward(self, x):
        return self.model(x)

class RNN(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_layers, embed_size=256):
        super().__init__()
        # Embed the caption tokens
        self.embed = nn.Embedding(vocab_size, embed_size)

        # Initialize the remaining layers
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.drop_out = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.drop_out(self.embed(captions))
        # Prepend the image features as the first "token" of the sequence
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden)

        return outputs

class ImgCaption_Model(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.CNN = CNN(embed_size)
        self.RNN = RNN(hidden_size, vocab_size, num_layers, embed_size)

    def forward(self, images, captions):
        features = self.CNN(images)
        outputs = self.RNN(features, captions)

        return outputs

    def caption_image(self, image, vocab, max_length=50):
        result = []

        with torch.inference_mode():
            features = self.CNN(image)
            state = None
            for _ in range(max_length):
                hidden, state = self.RNN.lstm(features, state)
                output = self.RNN.linear(hidden)
                predict = output.argmax(dim=1)

                if vocab.itos[predict.item()] == "<EOS>":
                    break

                result.append(predict.item())
                # Feed the predicted token back in as the next input
                features = self.RNN.embed(predict)

        # The first predicted token corresponds to "<SOS>", so it is skipped
        return [vocab.itos[idx] for idx in result[1:]]

if __name__ == '__main__':
    pass
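
A minimal shape check for the forward pass; the vocabulary size of 5000 is a made-up placeholder:

import torch
from model import ImgCaption_Model

model = ImgCaption_Model(embed_size=256, hidden_size=256, vocab_size=5000, num_layers=1)
images = torch.randn(4, 3, 224, 224)        # batch of 4 RGB images
captions = torch.randint(0, 5000, (12, 4))  # (seq_len, batch) token indices
outputs = model(images, captions)
print(outputs.shape)  # torch.Size([13, 4, 5000]): the image features add one sequence position
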
requirement.txt
ADDED
@@ -0,0 +1,4 @@
gradio
torch
spacy
torchvision
matplotlib  # imported by utils.py
savedir/best.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cdac83fb7cfc485259434762661b38b749cfca2df720fd583e36bac929a3a968
size 104784308
savedir/last.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bdef3c9612113193c97196cebd5ef3b9115aceb3ba60b572e8897217b3c973cf
size 104784308
utils.py
ADDED
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
"""
@author: Van Duc <vvduc03@gmail.com>
"""
"""Import necessary packages"""
import os
import torch
import random
import matplotlib.pyplot as plt

def read_caption(num_caption, vocab):
    """
    Convert a caption from token indices to words.
    Args:
        num_caption: caption as a tensor of token indices
        vocab: vocabulary object
    Returns:
        A list of strings (e.g. ['a', 'dog', 'in', 'the', 'sky'])
    """
    str_caption = []
    # Skip the leading <SOS> token and stop at <EOS>
    for cap in num_caption[1:]:
        if vocab.itos[cap.item()] == "<EOS>":
            break
        str_caption.append(cap)

    return [vocab.itos[idx.item()] for idx in str_caption]

def plot_examples(model, device, dataset, vocab, num_examples=20):
    """
    Plot the image, the ground-truth caption and the predicted caption for some images in the dataset.

    Args:
        model: pretrained model used to predict captions
        device: target device (cpu or gpu)
        dataset: dataset to sample from
        vocab: vocabulary object
        num_examples: number of examples to plot

    Returns:
        Figures showing each picture with its captions
    """
    model.eval()
    model.to(device)

    # Loop over examples
    for example in range(num_examples):
        # Take a random example from the dataset (randint is inclusive, hence len - 1)
        image, caption = dataset[random.randint(0, len(dataset) - 1)]
        image = image.to(device)

        # Print output
        correct = f"Example {example+1} CORRECT: " + " ".join(read_caption(caption, vocab))
        output = f"Example {example+1} OUTPUT: " + " ".join(model.caption_image(image.unsqueeze(0), vocab))
        print(correct)
        print(output)
        print('----------------------------------------------')

        # Plot the image and both captions
        fig, ax = plt.subplots()
        ax.imshow(dataset.img)
        ax.axis('off')
        fig.text(0.5, 0.05,
                 correct + '\n' + output,
                 ha="center")

        plt.show()

    model.train()


def save_checkpoint(model, optimizer, epoch, save_path, last_loss, best_loss):
    print("=> Saving checkpoint")
    checkpoint = {
        "epoch": epoch + 1,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }

    # Always save the latest weights; also save as best.pt when the loss improves
    torch.save(checkpoint, os.path.join(save_path, "last.pt"))
    if last_loss < best_loss:
        best_loss = last_loss
        torch.save(checkpoint, os.path.join(save_path, "best.pt"))

    return best_loss

def load_check_point_to_use(checkpoint_file, model, device):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["model"])

    return model

def load_checkpoint_to_continue(checkpoint_file, model, optimizer, lr, device):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file + '/last.pt', map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    epoch = checkpoint["epoch"]

    # If we don't do this, the optimizer keeps the learning rate of the old
    # checkpoint, which can lead to many hours of debugging :\
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    return model, epoch
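
A minimal sketch of the save/resume round trip; the model sizes, the learning rate and the loss value are placeholders, and train_one_epoch is a hypothetical helper:

import torch
from model import ImgCaption_Model
from utils import save_checkpoint, load_checkpoint_to_continue

model = ImgCaption_Model(256, 256, vocab_size=5000, num_layers=1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

best_loss = float("inf")
for epoch in range(2):
    last_loss = 1.0  # placeholder; in training this would be train_one_epoch(model, optimizer)
    best_loss = save_checkpoint(model, optimizer, epoch, "savedir", last_loss, best_loss)

# Resume later from savedir/last.pt, resetting the learning rate
model, start_epoch = load_checkpoint_to_continue("savedir", model, optimizer, lr=3e-4, device="cpu")
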
vocab.json
ADDED
The diff for this file is too large to render.