Spaces · ovi054
Commit 00abfdc — first commit
committed • Parent: b1fd26f
Build status: Build error
Browse files:
- app.py +98 -0
- data_loader.py +152 -0
- model.py +64 -0
- models/decoder-3.pkl +3 -0
- models/encoder-3.pkl +3 -0
- models/vocab.pkl +3 -0
- requirements.txt +1 -0
- vocabulary.py +95 -0
app.py
ADDED
@@ -0,0 +1,98 @@
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
import os

from data_loader import get_loader
from model import EncoderCNN, DecoderRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The saved models to load.
encoder_file = 'encoder-3.pkl'
decoder_file = 'decoder-3.pkl'

# Embedding and hidden dimensions used during training.
embed_size = 256
hidden_size = 512

# The size of the vocabulary.
vocab_size = 8855

# Initialize the encoder and decoder, and set each to inference mode.
encoder = EncoderCNN(embed_size)
encoder.eval()
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
decoder.eval()

# Load the trained weights from the repo's models/ folder (CPU-safe).
encoder.load_state_dict(torch.load(os.path.join('models', encoder_file), map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(os.path.join('models', decoder_file), map_location=torch.device('cpu')))

# Move models to GPU if CUDA is available.
encoder.to(device)
decoder.to(device)


def process_image(image):
    """Scale, crop, and normalize a PIL image for a PyTorch model."""
    transformation = transforms.Compose([
        transforms.Resize(256),                      # smaller edge of image resized to 256
        transforms.RandomCrop(224),                  # get 224x224 crop from random location
        transforms.ToTensor(),                       # convert the PIL Image to a tensor
        transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                             (0.229, 0.224, 0.225))])
    return transformation(image)


def load_image(img_np):
    """Convert a numpy image (as supplied by Gradio) to a PIL image and
    return both the original array and the pre-processed tensor."""
    PIL_image = Image.fromarray(img_np).convert('RGB')
    orig_image = np.array(PIL_image)
    image = process_image(PIL_image)
    return orig_image, image


def clean_sentence(output):
    """Turn a list of word indices into a sentence: skip <start> (index 0),
    stop at <end> (index 1), and attach '.' (index 18) to the preceding word."""
    sentence = ''
    for i in output:
        word = data_loader.dataset.vocab.idx2word[i]
        if i == 0:
            continue
        if i == 1:
            break
        if i == 18:
            sentence = sentence + word
        else:
            sentence = sentence + ' ' + word
    return sentence.strip()


# Build the test-mode data loader; its dataset carries the vocabulary
# used to map indices back to words.
data_loader = get_loader(transform=process_image, mode='test')


def get_caption(image):
    # Pass the actual image array through, not a string literal.
    orig_image, image = load_image(image)
    image = image.unsqueeze(0)
    plt.imshow(np.squeeze(orig_image))
    plt.title('Sample Image')
    plt.show()
    image = image.to(device)
    features = encoder(image).unsqueeze(1)
    output = decoder.sample(features)
    sentence = clean_sentence(output)
    return sentence


import gradio as gr

# get_caption returns a caption string, so the output component is text.
demo = gr.Interface(fn=get_caption, inputs="image", outputs="text")
demo.launch()
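For a quick local sanity check, get_caption can be exercised directly with a numpy image, mirroring what Gradio hands to the callback. A minimal sketch, assuming a local test image named sample.jpg (hypothetical; not part of this commit), appended at the bottom of app.py in place of demo.launch():

import numpy as np
from PIL import Image

# Hypothetical local test image; Gradio passes the callback a numpy array.
img_np = np.array(Image.open('sample.jpg').convert('RGB'))
caption = get_caption(img_np)
print(caption)  # the predicted caption string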
data_loader.py
ADDED
@@ -0,0 +1,152 @@
import nltk
import os
import torch
import torch.utils.data as data
from vocabulary import Vocabulary
from PIL import Image
from pycocotools.coco import COCO
import numpy as np
from tqdm import tqdm
import json

def get_loader(transform,
               mode='train',
               batch_size=1,
               vocab_threshold=None,
               vocab_file='models/vocab.pkl',
               start_word="<start>",
               end_word="<end>",
               unk_word="<unk>",
               vocab_from_file=True,
               num_workers=0,
               cocoapi_loc='/opt'):
    """Returns the data loader.
    Args:
      transform: Image transform.
      mode: One of 'train' or 'test'.
      batch_size: Batch size (if in testing mode, must have batch_size=1).
      vocab_threshold: Minimum word count threshold.
      vocab_file: File containing the vocabulary.
      start_word: Special word denoting sentence start.
      end_word: Special word denoting sentence end.
      unk_word: Special word denoting unknown words.
      vocab_from_file: If False, create vocab from scratch and override any existing vocab_file.
                       If True, load vocab from the existing vocab_file, if it exists.
      num_workers: Number of subprocesses to use for data loading.
      cocoapi_loc: The location of the folder containing the COCO API: https://github.com/cocodataset/cocoapi
    """

    assert mode in ['train', 'test'], "mode must be one of 'train' or 'test'."
    if not vocab_from_file:
        assert mode == 'train', "To generate vocab from captions file, must be in training mode (mode='train')."

    # Based on mode (train or test), obtain img_folder and annotations_file.
    if mode == 'train':
        if vocab_from_file:
            assert os.path.exists(vocab_file), "vocab_file does not exist. Change vocab_from_file to False to create vocab_file."
        img_folder = os.path.join(cocoapi_loc, 'cocoapi/images/train2014/')
        annotations_file = os.path.join(cocoapi_loc, 'cocoapi/annotations/captions_train2014.json')
    if mode == 'test':
        assert batch_size == 1, "Please change batch_size to 1 if testing your model."
        assert os.path.exists(vocab_file), "Must first generate vocab.pkl from training data."
        assert vocab_from_file, "Change vocab_from_file to True."
        # Colab-specific paths carried over from the training notebook.
        img_folder = '/content/opt/cocoapi/images/test2014'
        annotations_file = '/content/gdrive/MyDrive/image_info_test2014.json'

    # COCO caption dataset.
    dataset = CoCoDataset(transform=transform,
                          mode=mode,
                          batch_size=batch_size,
                          vocab_threshold=vocab_threshold,
                          vocab_file=vocab_file,
                          start_word=start_word,
                          end_word=end_word,
                          unk_word=unk_word,
                          annotations_file=annotations_file,
                          vocab_from_file=vocab_from_file,
                          img_folder=img_folder)

    if mode == 'train':
        # Randomly sample a caption length, and sample indices with that length.
        indices = dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        # Data loader for the COCO dataset.
        data_loader = data.DataLoader(dataset=dataset,
                                      num_workers=num_workers,
                                      batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                                              batch_size=dataset.batch_size,
                                                                              drop_last=False))
    else:
        data_loader = data.DataLoader(dataset=dataset,
                                      batch_size=dataset.batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)

    return data_loader

class CoCoDataset(data.Dataset):

    def __init__(self, transform, mode, batch_size, vocab_threshold, vocab_file, start_word,
                 end_word, unk_word, annotations_file, vocab_from_file, img_folder):
        self.transform = transform
        self.mode = mode
        self.batch_size = batch_size
        self.vocab = Vocabulary(vocab_threshold, vocab_file, start_word,
                                end_word, unk_word, annotations_file, vocab_from_file)
        self.img_folder = img_folder
        if self.mode == 'train':
            self.coco = COCO(annotations_file)
            self.ids = list(self.coco.anns.keys())
            print('Obtaining caption lengths...')
            all_tokens = [nltk.tokenize.word_tokenize(str(self.coco.anns[self.ids[index]]['caption']).lower()) for index in tqdm(np.arange(len(self.ids)))]
            self.caption_lengths = [len(token) for token in all_tokens]
        else:
            test_info = json.loads(open(annotations_file).read())
            self.paths = [item['file_name'] for item in test_info['images']]

    def __getitem__(self, index):
        # Obtain image and caption if in training mode.
        if self.mode == 'train':
            ann_id = self.ids[index]
            caption = self.coco.anns[ann_id]['caption']
            img_id = self.coco.anns[ann_id]['image_id']
            path = self.coco.loadImgs(img_id)[0]['file_name']

            # Convert image to tensor and pre-process using transform.
            image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
            image = self.transform(image)

            # Convert caption to tensor of word ids.
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())
            caption = []
            caption.append(self.vocab(self.vocab.start_word))
            caption.extend([self.vocab(token) for token in tokens])
            caption.append(self.vocab(self.vocab.end_word))
            caption = torch.Tensor(caption).long()

            # Return pre-processed image and caption tensors.
            return image, caption

        # Obtain image if in test mode.
        else:
            path = self.paths[index]

            # Convert image to tensor and pre-process using transform.
            PIL_image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
            orig_image = np.array(PIL_image)
            image = self.transform(PIL_image)

            # Return original image and pre-processed image tensor.
            return orig_image, image

    def get_train_indices(self):
        # Sample a caption length, then sample batch_size indices with that length.
        sel_length = np.random.choice(self.caption_lengths)
        all_indices = np.where([self.caption_lengths[i] == sel_length for i in np.arange(len(self.caption_lengths))])[0]
        indices = list(np.random.choice(all_indices, size=self.batch_size))
        return indices

    def __len__(self):
        if self.mode == 'train':
            return len(self.ids)
        else:
            return len(self.paths)
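For context, a sketch of how the test-mode loader is consumed. This assumes the hard-coded test image folder and annotations file above actually exist in the runtime environment, which is not the case on this Space:

from torchvision import transforms
from data_loader import get_loader

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])

data_loader = get_loader(transform=transform_test, mode='test')

# In test mode each item is (original image, pre-processed tensor).
orig_image, image = next(iter(data_loader))
print(image.shape)  # torch.Size([1, 3, 224, 224])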
model.py
ADDED
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torchvision.models as models


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # Pre-trained ResNet-50 backbone, frozen for feature extraction.
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final classification layer; embed the pooled features.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()

        self.hidden_dim = hidden_size

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.hidden = (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))

    def forward(self, features, captions):
        # Embed the captions (dropping <end>) and prepend the image features.
        cap_embedding = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), cap_embedding), 1)

        lstm_out, self.hidden = self.lstm(embeddings)
        outputs = self.linear(lstm_out)
        return outputs

    def sample(self, inputs, hidden=None, max_len=20):
        """Accept a pre-processed image tensor (inputs) and return a
        predicted sentence (list of word indices of length at most max_len)."""
        res = []
        for i in range(max_len):
            # One LSTM step: inputs has shape (1, 1, embed_size).
            outputs, hidden = self.lstm(inputs, hidden)
            outputs = self.linear(outputs.squeeze(1))
            # Greedy decoding: take the most likely next word.
            target_index = outputs.max(1)[1]
            res.append(target_index.item())
            # Feed the predicted word's embedding back in as the next input.
            inputs = self.embed(target_index).unsqueeze(1)
        return res
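A sketch of the tensor shapes flowing through sample(), using the dimensions from app.py (constructing EncoderCNN downloads the ResNet-50 weights on first run):

import torch
from model import EncoderCNN, DecoderRNN

encoder = EncoderCNN(embed_size=256)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=8855)
encoder.eval()
decoder.eval()

with torch.no_grad():
    image = torch.randn(1, 3, 224, 224)             # one pre-processed image
    features = encoder(image)                       # (1, 256)
    output = decoder.sample(features.unsqueeze(1))  # (1, 1, 256) seeds the LSTM

print(len(output))  # at most max_len = 20 word indices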
models/decoder-3.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3bb7e576b01e41b358e9b43a823a89e71331cc3854733eae590fcbb631e3b0f
size 33546937
models/encoder-3.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ab1ce1da221b8eb5ab980ba434c7ffb52beb5a18fa569aa1223ade37e8e752b
size 96387105
models/vocab.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4217bb73c18c91e7056154c08ee5c24f14adc9d91c334d63fb370b443f3eaa3
size 242231
requirements.txt
ADDED
@@ -0,0 +1 @@
git+https://github.com/philferriere/cocoapi.git#egg=pycocotools&subdirectory=PythonAPI
vocabulary.py
ADDED
@@ -0,0 +1,95 @@
import nltk
import pickle
import os.path
from pycocotools.coco import COCO
from collections import Counter

class Vocabulary(object):

    def __init__(self,
                 vocab_threshold,
                 vocab_file='models/vocab.pkl',
                 start_word="<start>",
                 end_word="<end>",
                 unk_word="<unk>",
                 annotations_file='../cocoapi/annotations/captions_train2014.json',
                 vocab_from_file=False):
        """Initialize the vocabulary.
        Args:
          vocab_threshold: Minimum word count threshold.
          vocab_file: File containing the vocabulary.
          start_word: Special word denoting sentence start.
          end_word: Special word denoting sentence end.
          unk_word: Special word denoting unknown words.
          annotations_file: Path to the train annotations file.
          vocab_from_file: If False, create vocab from scratch and override any existing vocab_file.
                           If True, load vocab from the existing vocab_file, if it exists.
        """
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.annotations_file = annotations_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()

    def get_vocab(self):
        """Load the vocabulary from file OR build the vocabulary from scratch."""
        if os.path.exists(self.vocab_file) and self.vocab_from_file:
            with open(self.vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
            print('Vocabulary successfully loaded from vocab.pkl file!')
        else:
            self.build_vocab()
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)

    def build_vocab(self):
        """Populate the dictionaries for converting tokens to integers (and vice-versa)."""
        self.init_vocab()
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Initialize the dictionaries for converting tokens to integers (and vice-versa)."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Add a token to the vocabulary."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        """Loop over training captions and add all tokens to the vocabulary that meet or exceed the threshold."""
        coco = COCO(self.annotations_file)
        counter = Counter()
        ids = coco.anns.keys()
        for i, id in enumerate(ids):
            caption = str(coco.anns[id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions..." % (i, len(ids)))

        words = [word for word, cnt in counter.items() if cnt >= self.vocab_threshold]

        for i, word in enumerate(words):
            self.add_word(word)

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
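Because get_vocab() pickles the whole Vocabulary object, loading models/vocab.pkl elsewhere requires the Vocabulary class to be importable so pickle can resolve it. A minimal sketch of inspecting the shipped vocabulary directly:

import pickle
from vocabulary import Vocabulary  # needed so pickle can resolve the class

with open('models/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

print(len(vocab.word2idx))        # should match vocab_size = 8855 in app.py
print(vocab.word2idx['<start>'])  # 0 — skipped by clean_sentence
print(vocab.word2idx['<end>'])    # 1 — stops decoding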