vvd2003 committed on
Commit
bbc5e76
1 Parent(s): c78dd05

Upload 15 files

app.py ADDED
@@ -0,0 +1,86 @@
+ # -*- coding: utf-8 -*-
+ """
+ @author: Van Duc <vvduc03@gmail.com>
+ """
+ """Import necessary packages"""
+ import os
+ import argparse
+ import config
+ import gradio as gr
+
+ from model import ImgCaption_Model
+ from dataset import Vocabulary
+ from timeit import default_timer as timer
+ from utils import load_check_point_to_use
+
+ # Initialize and parse command-line arguments
+ def get_args():
+     parse = argparse.ArgumentParser()
+     parse.add_argument('--save-path', '-s', type=str, default=config.save_path, help='directory containing the saved model checkpoints')
+     parse.add_argument('--transform', default=config.transform, help='Compose transform applied to input images')
+     parse.add_argument('--embed-size', default=config.embed_size, help='size of the embedding vectors')
+     parse.add_argument('--hidden-size', default=config.hidden_size, help='number of hidden units in the RNN')
+     parse.add_argument('--num-layer', default=config.num_layer, help='number of stacked LSTM layers')
+     parse.add_argument('--num-workers', default=config.num_workers, help='number of CPU workers used to load data')
+     args = parse.parse_args()
+     return args
+
+ # Load vocabulary file
+ vocab = Vocabulary()
+ vocab.read_vocab()
+
+ # Load arguments
+ args = get_args()
+
+ # Build model
+ model = ImgCaption_Model(args.embed_size, args.hidden_size, len(vocab), args.num_layer)
+
+ # Load saved weights
+ load_check_point_to_use(args.save_path + '/best.pt', model, 'cpu')
+
+ def caption(img):
+     """Transform the image, generate a caption, and return the caption and the time taken."""
+     # Start the timer
+     start_time = timer()
+
+     # Transform the target image
+     img = args.transform(img)
+
+     # Put the model into evaluation mode and caption the image
+     model.eval()
+     prompt = " ".join(model.caption_image(img.unsqueeze(0), vocab))
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Return the caption and prediction time
+     return prompt, pred_time
+
+
+ # Create title, description and article strings
+ def main():
+     title = "Image Captioning 🖼➡️🆎"
+     description = "A model that describes the content of a picture"
+     article = "Created on [GITHUB](https://github.com/vvduc1803/Image-Captioning)."
+
+     # Create examples list from the "examples/" directory
+     example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+     # Create the Gradio demo
+     demo = gr.Interface(fn=caption,                            # mapping function from input to output
+                         inputs=gr.Image(type="pil"),           # input: a PIL image
+                         outputs=[gr.Textbox(label="Caption"),  # two outputs, because caption() returns two values
+                                  gr.Number(label="Prediction time (s)")],
+                         examples=example_list,
+                         title=title,
+                         description=description,
+                         article=article)
+
+     # Launch the demo!
+     demo.launch(server_name="127.0.0.1", server_port=1234, share=True)
+
+ if __name__ == '__main__':
+     main()
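Note that app.py and dataset.py both import a `config` module that is not among the 15 uploaded files, so the demo will not run until it is supplied. Below is a minimal sketch of what it might contain, inferred from the attributes accessed above; every concrete value is an assumption, not something taken from this commit.

# config.py -- a minimal sketch; this module is imported by app.py and dataset.py
# but is not part of this commit, so every value below is an assumption.
import torchvision.transforms as transforms

save_path = 'savedir'      # directory holding best.pt / last.pt
embed_size = 256           # must match the checkpoint that is loaded
hidden_size = 256
num_layer = 1
num_workers = 2
batch_size = 16

# Paths used by dataset.CoCoDataset (COCO-style layout, assumed)
train = 'data/train'
images = 'images'
captions = 'captions.json'

# Preprocessing applied to every input image (assumed; Resize + ToTensor at minimum)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])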
dataset.py ADDED
@@ -0,0 +1,191 @@
+ # -*- coding: utf-8 -*-
+ """
+ @author: Van Duc <vvduc03@gmail.com>
+ """
+ """Import necessary packages"""
+ import os
+ import spacy  # for the tokenizer
+ import torch
+ import config
+ import json
+
+ from torch.nn.utils.rnn import pad_sequence  # pad batch
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image
+
+ # Download with: python -m spacy download en_core_web_sm
+ spacy_eng = spacy.load("en_core_web_sm")
+
+ class Vocabulary:
+     def __init__(self, freq_threshold=5):
+         # Initialize two dictionaries: index-to-string and string-to-index
+         self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
+         self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
+
+         # Minimum frequency for a word to be added to the vocabulary
+         self.freq_threshold = freq_threshold
+
+     def __len__(self):
+         return len(self.itos)
+
+     @staticmethod
+     def tokenizer_eng(text):
+         return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
+
+     def build_vocabulary(self, sentence_list):
+         frequencies = {}
+         idx = 4
+
+         for sentence in sentence_list:
+             for word in self.tokenizer_eng(sentence):
+                 if word not in frequencies:
+                     frequencies[word] = 1
+                 else:
+                     frequencies[word] += 1
+
+                 # Add the word once it reaches the frequency threshold
+                 if frequencies[word] == self.freq_threshold:
+                     self.stoi[word] = idx
+                     self.itos[idx] = word
+                     idx += 1
+
+     def read_vocab(self, file_name='vocab.json'):
+         """
+         Load the saved vocabulary file and replace the 'index to string' and 'string to index' dictionaries
+         """
+         with open(file_name, 'r') as vocab_file:
+             vocab = json.load(vocab_file)
+
+         # JSON keys are strings, so cast the itos keys back to int
+         self.itos = {int(key): value for key, value in vocab['itos'].items()}
+         self.stoi = vocab['stoi']
+
+     def create_vocab(self, file_name='vocab.json'):
+         # Serialize both dictionaries to JSON and write them to the file
+         vocab = json.dumps({'itos': self.itos,
+                             'stoi': self.stoi})
+         with open(file_name, "w") as f:
+             f.write(vocab)
+
+     def numericalize(self, text):
+         tokenized_text = self.tokenizer_eng(text)
+
+         return [
+             self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
+             for token in tokenized_text
+         ]
+
+ class CoCoDataset(Dataset):
+     def __init__(self, root_dir, transform=None, freq_threshold=5):
+         self.root_dir = root_dir
+         self.freq_threshold = freq_threshold
+         self.transform = transform
+
+         with open(os.path.join(self.root_dir, config.captions), 'r') as captions_path:
+             captions_file = json.load(captions_path)
+
+         # Get the image-id and caption columns
+         self.imageID_list = [captions['image_id'] for captions in captions_file['annotations']]
+         self.captions_list = [captions['caption'] for captions in captions_file['annotations']]
+
+         # Load the pre-built vocabulary file
+         # (it can be rebuilt from the captions with build_vocabulary + create_vocab)
+         self.vocab = Vocabulary(self.freq_threshold)
+         self.vocab.read_vocab()
+
+     def __len__(self):
+         return len(self.imageID_list)
+
+     def __getitem__(self, index):
+         # Load the caption and image at this index
+         caption = self.captions_list[index]
+         img_id = str(self.imageID_list[index]).zfill(12) + '.jpg'
+         self.img = Image.open(os.path.join(self.root_dir, config.images, img_id)).convert("RGB")
+
+         # Transform the image
+         img = self.img
+         if self.transform:
+             img = self.transform(self.img)
+
+         # Numericalize the caption: <SOS> + tokens + <EOS>
+         numericalized_caption = [self.vocab.stoi["<SOS>"]]
+         numericalized_caption += self.vocab.numericalize(caption)
+         numericalized_caption.append(self.vocab.stoi["<EOS>"])
+
+         return img, torch.tensor(numericalized_caption)
+
+ class MyCollate:
+     def __init__(self, pad_idx):
+         self.pad_idx = pad_idx
+
+     def __call__(self, batch):
+         imgs = [item[0].unsqueeze(0) for item in batch]
+         imgs = torch.cat(imgs, dim=0)
+         targets = [item[1] for item in batch]
+         targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
+
+         return imgs, targets
+
+
+ def get_loader(
+     root_folder,
+     transform,
+     batch_size=16,
+     num_workers=4,
+     shuffle=True,
+     pin_memory=True
+ ):
+     dataset = CoCoDataset(root_folder, transform=transform)
+
+     pad_idx = dataset.vocab.stoi["<PAD>"]
+
+     loader = DataLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         pin_memory=pin_memory,
+         collate_fn=MyCollate(pad_idx=pad_idx),
+     )
+     return dataset, loader
+
+
+ if __name__ == "__main__":
+     from utils import plot_examples
+     from model import ImgCaption_Model
+
+     train_dataset, train_loader = get_loader(root_folder=config.train,
+                                              transform=config.transform,
+                                              batch_size=config.batch_size,
+                                              num_workers=config.num_workers,
+                                              shuffle=True)
+
+     model = ImgCaption_Model(256, 256, len(train_dataset.vocab), 1)
+     plot_examples(model, 'cuda', train_dataset, train_dataset.vocab)
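Vocabulary.read_vocab() expects a pre-built vocab.json, which is included in this commit. For reference, here is a one-off sketch of how such a file could be regenerated from the training captions using build_vocabulary and create_vocab; the annotation path and the config attributes it uses are assumptions.

# build_vocab.py -- one-off sketch for regenerating vocab.json; the file layout
# and config attributes used here are assumptions, not part of this commit.
import json
import os

import config
from dataset import Vocabulary

# Load the COCO-style annotation file (assumed location: config.train/config.captions)
with open(os.path.join(config.train, config.captions), 'r') as f:
    annotations = json.load(f)['annotations']
captions = [entry['caption'] for entry in annotations]

# Build the vocabulary from the training captions and write vocab.json
vocab = Vocabulary(freq_threshold=5)
vocab.build_vocabulary(captions)
vocab.create_vocab('vocab.json')
print(f"Saved vocab.json with {len(vocab)} tokens")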
examples/000000000139.jpg ADDED
examples/000000000785.jpg ADDED
examples/000000005477.jpg ADDED
examples/good1.png ADDED
examples/good2.png ADDED
examples/good3.png ADDED
examples/good6.png ADDED
model.py ADDED
@@ -0,0 +1,86 @@
+ # -*- coding: utf-8 -*-
+ """
+ @author: Van Duc <vvduc03@gmail.com>
+ """
+ """Import necessary packages"""
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+ class CNN(nn.Module):
+     def __init__(self, embed_size=256, train_model=False):
+         super().__init__()
+
+         # Load a pretrained EfficientNet-B2 backbone
+         self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
+
+         # Freeze all layers of the backbone
+         if not train_model:
+             for param in self.model.parameters():
+                 param.requires_grad = False
+
+         # Replace the head of the model (the new head is trainable by default)
+         self.model.classifier = nn.Sequential(nn.Linear(1408, embed_size),
+                                               nn.ReLU(),
+                                               nn.Dropout(0.5))
+
+     def forward(self, x):
+         return self.model(x)
+
+ class RNN(nn.Module):
+     def __init__(self, hidden_size, vocab_size, num_layers, embed_size=256):
+         super().__init__()
+         # Caption embedding
+         self.embed = nn.Embedding(vocab_size, embed_size)
+
+         # Initialize the remaining layers
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.drop_out = nn.Dropout(0.5)
+
+     def forward(self, features, captions):
+         embeddings = self.drop_out(self.embed(captions))
+         # Prepend the image feature as the first step of the sequence
+         embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
+         hidden, _ = self.lstm(embeddings)
+         outputs = self.linear(hidden)
+
+         return outputs
+
+ class ImgCaption_Model(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
+         super().__init__()
+         self.CNN = CNN(embed_size)
+         self.RNN = RNN(hidden_size, vocab_size, num_layers, embed_size)
+
+     def forward(self, images, captions):
+         features = self.CNN(images)
+         outputs = self.RNN(features, captions)
+
+         return outputs
+
+     def caption_image(self, image, vocab, max_length=50):
+         """Greedily generate a caption for a single image (batch size 1)."""
+         result = []
+
+         with torch.inference_mode():
+             features = self.CNN(image)
+             state = None
+             for _ in range(max_length):
+                 hidden, state = self.RNN.lstm(features, state)
+                 output = self.RNN.linear(hidden)
+                 predict = output.argmax(dim=1)
+
+                 if vocab.itos[predict.item()] == "<EOS>":
+                     break
+
+                 result.append(predict.item())
+                 features = self.RNN.embed(predict)
+
+         # The first prediction corresponds to the image-feature step (i.e. "<SOS>"), so skip it
+         return [vocab.itos[idx] for idx in result[1:]]
+
+ if __name__ == '__main__':
+     pass
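A minimal shape check for ImgCaption_Model, under the assumption of a 1000-token vocabulary and 224x224 inputs (illustrative values, not taken from this repo); instantiating CNN downloads the pretrained EfficientNet-B2 weights on first use.

# smoke_test.py -- a minimal shape check for ImgCaption_Model; the sizes below
# are assumptions chosen for illustration only.
import torch
from model import ImgCaption_Model

vocab_size = 1000
model = ImgCaption_Model(embed_size=256, hidden_size=256, vocab_size=vocab_size, num_layers=1)

images = torch.randn(4, 3, 224, 224)              # batch of 4 RGB images
captions = torch.randint(0, vocab_size, (12, 4))  # (seq_len, batch) token indices

# Training-style forward pass: one logit vector per time step and batch element,
# with the image feature prepended as an extra first step
outputs = model(images, captions)
print(outputs.shape)  # torch.Size([13, 4, 1000])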
requirement.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ torch
+ spacy
+ torchvision
savedir/best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cdac83fb7cfc485259434762661b38b749cfca2df720fd583e36bac929a3a968
+ size 104784308
savedir/last.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bdef3c9612113193c97196cebd5ef3b9115aceb3ba60b572e8897217b3c973cf
+ size 104784308
utils.py ADDED
@@ -0,0 +1,105 @@
+ # -*- coding: utf-8 -*-
+ """
+ @author: Van Duc <vvduc03@gmail.com>
+ """
+ """Import necessary packages"""
+ import os
+ import torch
+ import random
+ import matplotlib.pyplot as plt
+
+ def read_caption(num_caption, vocab):
+     """
+     Convert a caption from token indices to words
+     Args:
+         num_caption: caption as a tensor of token indices
+         vocab: vocabulary object
+     Returns:
+         A list of words (e.g. ['a', 'dog', 'in', 'the', 'sky'])
+     """
+     str_caption = []
+     for cap in num_caption[1:]:  # skip the leading <SOS> token
+         if vocab.itos[cap.item()] == "<EOS>":
+             break
+         str_caption.append(cap)
+
+     return [vocab.itos[idx.item()] for idx in str_caption]
+
+ def plot_examples(model, device, dataset, vocab, num_examples=20):
+     """
+     Plot images together with their ground-truth and predicted captions
+
+     Args:
+         model: pretrained model used to predict captions
+         device: target device ('cpu' or 'cuda')
+         dataset: dataset to sample examples from
+         vocab: vocabulary object
+         num_examples: number of examples to plot
+
+     Returns:
+         None (prints both captions and shows one figure per example)
+     """
+     model.eval()
+     model.to(device)
+
+     # Loop over examples
+     for example in range(num_examples):
+         # Take a random example from the dataset
+         image, caption = dataset[random.randint(0, len(dataset) - 1)]
+         image = image.to(device)
+
+         # Print output
+         correct = f"Example {example+1} CORRECT: " + " ".join(read_caption(caption, vocab))
+         output = f"Example {example+1} OUTPUT: " + " ".join(model.caption_image(image.unsqueeze(0), vocab))
+         print(correct)
+         print(output)
+         print('----------------------------------------------')
+
+         # Plot the image and both captions
+         fig, ax = plt.subplots()
+         ax.imshow(dataset.img)
+         ax.axis('off')
+         fig.text(0.5, 0.05,
+                  correct + '\n' + output,
+                  ha="center")
+
+         plt.show()
+
+     model.train()
+
+
+ def save_checkpoint(model, optimizer, epoch, save_path, last_loss, best_loss):
+     print("=> Saving checkpoint")
+     checkpoint = {
+         "epoch": epoch + 1,
+         "model": model.state_dict(),
+         "optimizer": optimizer.state_dict()
+     }
+
+     torch.save(checkpoint, os.path.join(save_path, "last.pt"))
+     if last_loss < best_loss:
+         best_loss = last_loss
+         torch.save(checkpoint, os.path.join(save_path, "best.pt"))
+
+     return best_loss
+
+ def load_check_point_to_use(checkpoint_file, model, device):
+     print("=> Loading checkpoint")
+     checkpoint = torch.load(checkpoint_file, map_location=device)
+     model.load_state_dict(checkpoint["model"])
+
+     return model
+
+ def load_checkpoint_to_continue(checkpoint_file, model, optimizer, lr, device):
+     print("=> Loading checkpoint")
+     checkpoint = torch.load(checkpoint_file + '/last.pt', map_location=device)
+     model.load_state_dict(checkpoint["model"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+     epoch = checkpoint["epoch"]
+
+     # Reset the learning rate; otherwise the optimizer keeps the old checkpoint's lr,
+     # which can lead to many hours of debugging :\
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+     return model, epoch
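save_checkpoint and load_checkpoint_to_continue are written for a training script that is not part of this commit. The sketch below shows how they could be wired into a loop; the loss setup, hyperparameters, and resume policy are assumptions, not the author's actual training code.

# train_loop.py -- a sketch of how the checkpoint helpers could be used; the
# training script is not in this commit, so the loop body is an assumption.
import torch
import config
from dataset import get_loader
from model import ImgCaption_Model
from utils import save_checkpoint, load_checkpoint_to_continue

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset, loader = get_loader(config.train, config.transform, batch_size=config.batch_size)

model = ImgCaption_Model(config.embed_size, config.hidden_size, len(dataset.vocab), config.num_layer).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])

start_epoch, best_loss = 0, float('inf')
# Optionally resume from savedir/last.pt (assumed resume policy):
# model, start_epoch = load_checkpoint_to_continue(config.save_path, model, optimizer, lr=3e-4, device=device)

for epoch in range(start_epoch, 10):
    total_loss = 0.0
    for imgs, captions in loader:
        imgs, captions = imgs.to(device), captions.to(device)
        outputs = model(imgs, captions[:-1])  # predict each next token given the previous ones
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Save last.pt every epoch, and best.pt whenever the epoch loss improves
    best_loss = save_checkpoint(model, optimizer, epoch, config.save_path, total_loss, best_loss)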
vocab.json ADDED
The diff for this file is too large to render. See raw diff