Spaces:
Runtime error
Runtime error
Alberto Carmona
commited on
Commit
•
ebd4e51
1
Parent(s):
7ec5667
Track error cloning the repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- captioning/__init__.py +0 -0
- captioning/data/__init__.py +0 -0
- captioning/data/dataloader.py +425 -0
- captioning/data/pth_loader.py +334 -0
- captioning/data/pth_loader_FineCapEval.py +334 -0
- captioning/models/AoAModel.py +228 -0
- captioning/models/AttEnsemble.py +90 -0
- captioning/models/AttModel.py +969 -0
- captioning/models/BertCapModel.py +104 -0
- captioning/models/CaptionModel.py +407 -0
- captioning/models/FCModel.py +204 -0
- captioning/models/M2Transformer.py +98 -0
- captioning/models/ShowTellModel.py +174 -0
- captioning/models/TransformerModel.py +363 -0
- captioning/models/__init__.py +73 -0
- captioning/models/cachedTransformer.py +420 -0
- captioning/models/utils.py +25 -0
- captioning/modules/loss_wrapper.py +127 -0
- captioning/modules/losses.py +218 -0
- captioning/utils/__init__.py +0 -0
- captioning/utils/clipscore.py +396 -0
- captioning/utils/config.py +153 -0
- captioning/utils/dist_utils.py +305 -0
- captioning/utils/div_utils.py +38 -0
- captioning/utils/eval_multi.py +218 -0
- captioning/utils/eval_utils.py +281 -0
- captioning/utils/misc.py +251 -0
- captioning/utils/opts.py +412 -0
- captioning/utils/resnet.py +71 -0
- captioning/utils/resnet_utils.py +27 -0
- captioning/utils/rewards.py +392 -0
- captioning/utils/utils.py +138 -0
- clip/__init__.py +1 -0
- clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- clip/clip.py +193 -0
- clip/model.py +437 -0
- clip/simple_tokenizer.py +132 -0
- configs/phase1/FineCapEval_clipRN50_mle.yml +60 -0
- configs/phase1/clipRN50_mle.yml +52 -0
- configs/phase1/transformer.yml +41 -0
- configs/phase2/FineCapEval_clipRN50_cider.yml +61 -0
- configs/phase2/FineCapEval_clipRN50_cider_clips.yml +65 -0
- configs/phase2/FineCapEval_clipRN50_clips.yml +64 -0
- configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +64 -0
- configs/phase2/clipRN50_cider.yml +58 -0
- configs/phase2/clipRN50_cider_clips.yml +61 -0
- configs/phase2/clipRN50_clips.yml +58 -0
- configs/phase2/clipRN50_clips_grammar.yml +64 -0
- configs/phase2/transformer.yml +41 -0
- data/README.md +1 -0
captioning/__init__.py
ADDED
File without changes
|
captioning/data/__init__.py
ADDED
File without changes
|
captioning/data/dataloader.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.data as data
|
17 |
+
|
18 |
+
import multiprocessing
|
19 |
+
import six
|
20 |
+
|
21 |
+
class HybridLoader:
|
22 |
+
"""
|
23 |
+
If db_path is a director, then use normal file loading
|
24 |
+
If lmdb, then load from lmdb
|
25 |
+
The loading method depend on extention.
|
26 |
+
|
27 |
+
in_memory: if in_memory is True, we save all the features in memory
|
28 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
29 |
+
Should be useful for lmdb or h5.
|
30 |
+
(Copied this idea from vilbert)
|
31 |
+
"""
|
32 |
+
def __init__(self, db_path, ext, in_memory=False):
|
33 |
+
self.db_path = db_path
|
34 |
+
self.ext = ext
|
35 |
+
if self.ext == '.npy':
|
36 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
37 |
+
else:
|
38 |
+
def load_npz(x):
|
39 |
+
x = np.load(six.BytesIO(x))
|
40 |
+
return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
|
41 |
+
self.loader = load_npz
|
42 |
+
if db_path.endswith('.lmdb'):
|
43 |
+
self.db_type = 'lmdb'
|
44 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
45 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
46 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
47 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
48 |
+
self.db_type = 'pth'
|
49 |
+
self.feat_file = torch.load(db_path)
|
50 |
+
self.loader = lambda x: x
|
51 |
+
print('HybridLoader: ext is ignored')
|
52 |
+
elif db_path.endswith('h5'):
|
53 |
+
self.db_type = 'h5'
|
54 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
55 |
+
else:
|
56 |
+
self.db_type = 'dir'
|
57 |
+
|
58 |
+
self.in_memory = in_memory
|
59 |
+
if self.in_memory:
|
60 |
+
self.features = {}
|
61 |
+
|
62 |
+
def get(self, key):
|
63 |
+
|
64 |
+
if self.in_memory and key in self.features:
|
65 |
+
# We save f_input because we want to save the
|
66 |
+
# compressed bytes to save memory
|
67 |
+
f_input = self.features[key]
|
68 |
+
elif self.db_type == 'lmdb':
|
69 |
+
f_input = self.lmdb[key]
|
70 |
+
elif self.db_type == 'pth':
|
71 |
+
f_input = self.feat_file[key]
|
72 |
+
elif self.db_type == 'h5':
|
73 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
74 |
+
else:
|
75 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
76 |
+
|
77 |
+
if self.in_memory and key not in self.features:
|
78 |
+
self.features[key] = f_input
|
79 |
+
|
80 |
+
# load image
|
81 |
+
feat = self.loader(f_input)
|
82 |
+
|
83 |
+
return feat
|
84 |
+
|
85 |
+
class Dataset(data.Dataset):
|
86 |
+
|
87 |
+
def get_vocab_size(self):
|
88 |
+
return self.vocab_size
|
89 |
+
|
90 |
+
def get_vocab(self):
|
91 |
+
return self.ix_to_word
|
92 |
+
|
93 |
+
def get_seq_length(self):
|
94 |
+
return self.seq_length
|
95 |
+
|
96 |
+
def __init__(self, opt):
|
97 |
+
self.opt = opt
|
98 |
+
self.seq_per_img = opt.seq_per_img
|
99 |
+
|
100 |
+
# feature related options
|
101 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
102 |
+
self.use_att = getattr(opt, 'use_att', True)
|
103 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
104 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
105 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
106 |
+
|
107 |
+
# load the json file which contains additional information about the dataset
|
108 |
+
print('DataLoader loading json file: ', opt.input_json)
|
109 |
+
self.info = json.load(open(self.opt.input_json))
|
110 |
+
if 'ix_to_word' in self.info:
|
111 |
+
self.ix_to_word = self.info['ix_to_word']
|
112 |
+
self.vocab_size = len(self.ix_to_word)
|
113 |
+
print('vocab size is ', self.vocab_size)
|
114 |
+
|
115 |
+
# open the hdf5 file
|
116 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
117 |
+
"""
|
118 |
+
Setting input_label_h5 to none is used when only doing generation.
|
119 |
+
For example, when you need to test on coco test set.
|
120 |
+
"""
|
121 |
+
if self.opt.input_label_h5 != 'none':
|
122 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
123 |
+
# load in the sequence data
|
124 |
+
seq_size = self.h5_label_file['labels'].shape
|
125 |
+
self.label = self.h5_label_file['labels'][:]
|
126 |
+
self.seq_length = seq_size[1]
|
127 |
+
print('max sequence length in data is', self.seq_length)
|
128 |
+
# load the pointers in full to RAM (should be small enough)
|
129 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
130 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
131 |
+
else:
|
132 |
+
self.seq_length = 1
|
133 |
+
|
134 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
135 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
136 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
137 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
138 |
+
|
139 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
140 |
+
print('read %d image features' %(self.num_images))
|
141 |
+
|
142 |
+
# separate out indexes for each of the provided splits
|
143 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
144 |
+
for ix in range(len(self.info['images'])):
|
145 |
+
img = self.info['images'][ix]
|
146 |
+
if not 'split' in img:
|
147 |
+
self.split_ix['train'].append(ix)
|
148 |
+
self.split_ix['val'].append(ix)
|
149 |
+
self.split_ix['test'].append(ix)
|
150 |
+
elif img['split'] == 'train':
|
151 |
+
self.split_ix['train'].append(ix)
|
152 |
+
elif img['split'] == 'val':
|
153 |
+
self.split_ix['val'].append(ix)
|
154 |
+
elif img['split'] == 'test':
|
155 |
+
self.split_ix['test'].append(ix)
|
156 |
+
elif opt.train_only == 0: # restval
|
157 |
+
self.split_ix['train'].append(ix)
|
158 |
+
|
159 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
160 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
161 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
162 |
+
|
163 |
+
def get_captions(self, ix, seq_per_img):
|
164 |
+
# fetch the sequence labels
|
165 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
166 |
+
ix2 = self.label_end_ix[ix] - 1
|
167 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
168 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
169 |
+
|
170 |
+
if ncap < seq_per_img:
|
171 |
+
# we need to subsample (with replacement)
|
172 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
173 |
+
for q in range(seq_per_img):
|
174 |
+
ixl = random.randint(ix1,ix2)
|
175 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
176 |
+
else:
|
177 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
178 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
179 |
+
|
180 |
+
return seq
|
181 |
+
|
182 |
+
def collate_func(self, batch, split):
|
183 |
+
seq_per_img = self.seq_per_img
|
184 |
+
|
185 |
+
fc_batch = []
|
186 |
+
att_batch = []
|
187 |
+
label_batch = []
|
188 |
+
|
189 |
+
wrapped = False
|
190 |
+
|
191 |
+
infos = []
|
192 |
+
gts = []
|
193 |
+
|
194 |
+
for sample in batch:
|
195 |
+
# fetch image
|
196 |
+
tmp_fc, tmp_att, tmp_seq, \
|
197 |
+
ix, it_pos_now, tmp_wrapped = sample
|
198 |
+
if tmp_wrapped:
|
199 |
+
wrapped = True
|
200 |
+
|
201 |
+
fc_batch.append(tmp_fc)
|
202 |
+
att_batch.append(tmp_att)
|
203 |
+
|
204 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
205 |
+
if hasattr(self, 'h5_label_file'):
|
206 |
+
# if there is ground truth
|
207 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
208 |
+
label_batch.append(tmp_label)
|
209 |
+
|
210 |
+
# Used for reward evaluation
|
211 |
+
if hasattr(self, 'h5_label_file'):
|
212 |
+
# if there is ground truth
|
213 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
214 |
+
else:
|
215 |
+
gts.append([])
|
216 |
+
|
217 |
+
# record associated info as well
|
218 |
+
info_dict = {}
|
219 |
+
info_dict['ix'] = ix
|
220 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
221 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
222 |
+
infos.append(info_dict)
|
223 |
+
|
224 |
+
# #sort by att_feat length
|
225 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
226 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
227 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
228 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
229 |
+
data = {}
|
230 |
+
data['fc_feats'] = np.stack(fc_batch)
|
231 |
+
# merge att_feats
|
232 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
233 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
234 |
+
for i in range(len(att_batch)):
|
235 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
236 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
237 |
+
for i in range(len(att_batch)):
|
238 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
239 |
+
# set att_masks to None if attention features have same length
|
240 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
241 |
+
data['att_masks'] = None
|
242 |
+
|
243 |
+
data['labels'] = np.vstack(label_batch)
|
244 |
+
# generate mask
|
245 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
246 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
247 |
+
for ix, row in enumerate(mask_batch):
|
248 |
+
row[:nonzeros[ix]] = 1
|
249 |
+
data['masks'] = mask_batch
|
250 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
251 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
252 |
+
|
253 |
+
data['gts'] = gts # all ground truth captions of each images
|
254 |
+
data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
|
255 |
+
'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
|
256 |
+
data['infos'] = infos
|
257 |
+
|
258 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
259 |
+
|
260 |
+
return data
|
261 |
+
|
262 |
+
def __getitem__(self, index):
|
263 |
+
"""This function returns a tuple that is further passed to collate_fn
|
264 |
+
"""
|
265 |
+
ix, it_pos_now, wrapped = index #self.split_ix[index]
|
266 |
+
if self.use_att:
|
267 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
268 |
+
# Reshape to K x C
|
269 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
270 |
+
if self.norm_att_feat:
|
271 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
272 |
+
if self.use_box:
|
273 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
274 |
+
# devided by image width and height
|
275 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
276 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
277 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
278 |
+
if self.norm_box_feat:
|
279 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
280 |
+
att_feat = np.hstack([att_feat, box_feat])
|
281 |
+
# sort the features by the size of boxes
|
282 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
283 |
+
else:
|
284 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
285 |
+
if self.use_fc:
|
286 |
+
try:
|
287 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
288 |
+
except:
|
289 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
290 |
+
fc_feat = att_feat.mean(0)
|
291 |
+
else:
|
292 |
+
fc_feat = np.zeros((0), dtype='float32')
|
293 |
+
if hasattr(self, 'h5_label_file'):
|
294 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
295 |
+
else:
|
296 |
+
seq = None
|
297 |
+
return (fc_feat,
|
298 |
+
att_feat, seq,
|
299 |
+
ix, it_pos_now, wrapped)
|
300 |
+
|
301 |
+
def __len__(self):
|
302 |
+
return len(self.info['images'])
|
303 |
+
|
304 |
+
class DataLoader:
|
305 |
+
def __init__(self, opt):
|
306 |
+
self.opt = opt
|
307 |
+
self.batch_size = self.opt.batch_size
|
308 |
+
self.dataset = Dataset(opt)
|
309 |
+
|
310 |
+
# Initialize loaders and iters
|
311 |
+
self.loaders, self.iters = {}, {}
|
312 |
+
for split in ['train', 'val', 'test']:
|
313 |
+
if split == 'train':
|
314 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
|
315 |
+
else:
|
316 |
+
sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
|
317 |
+
self.loaders[split] = data.DataLoader(dataset=self.dataset,
|
318 |
+
batch_size=self.batch_size,
|
319 |
+
sampler=sampler,
|
320 |
+
pin_memory=True,
|
321 |
+
num_workers=4, # 4 is usually enough
|
322 |
+
collate_fn=partial(self.dataset.collate_func, split=split),
|
323 |
+
drop_last=False)
|
324 |
+
self.iters[split] = iter(self.loaders[split])
|
325 |
+
|
326 |
+
def get_batch(self, split):
|
327 |
+
try:
|
328 |
+
data = next(self.iters[split])
|
329 |
+
except StopIteration:
|
330 |
+
self.iters[split] = iter(self.loaders[split])
|
331 |
+
data = next(self.iters[split])
|
332 |
+
return data
|
333 |
+
|
334 |
+
def reset_iterator(self, split):
|
335 |
+
self.loaders[split].sampler._reset_iter()
|
336 |
+
self.iters[split] = iter(self.loaders[split])
|
337 |
+
|
338 |
+
def get_vocab_size(self):
|
339 |
+
return self.dataset.get_vocab_size()
|
340 |
+
|
341 |
+
@property
|
342 |
+
def vocab_size(self):
|
343 |
+
return self.get_vocab_size()
|
344 |
+
|
345 |
+
def get_vocab(self):
|
346 |
+
return self.dataset.get_vocab()
|
347 |
+
|
348 |
+
def get_seq_length(self):
|
349 |
+
return self.dataset.get_seq_length()
|
350 |
+
|
351 |
+
@property
|
352 |
+
def seq_length(self):
|
353 |
+
return self.get_seq_length()
|
354 |
+
|
355 |
+
def state_dict(self):
|
356 |
+
def get_prefetch_num(split):
|
357 |
+
if self.loaders[split].num_workers > 0:
|
358 |
+
return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
|
359 |
+
else:
|
360 |
+
return 0
|
361 |
+
return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
|
362 |
+
for split, loader in self.loaders.items()}
|
363 |
+
|
364 |
+
def load_state_dict(self, state_dict=None):
|
365 |
+
if state_dict is None:
|
366 |
+
return
|
367 |
+
for split in self.loaders.keys():
|
368 |
+
self.loaders[split].sampler.load_state_dict(state_dict[split])
|
369 |
+
|
370 |
+
|
371 |
+
class MySampler(data.sampler.Sampler):
|
372 |
+
def __init__(self, index_list, shuffle, wrap):
|
373 |
+
self.index_list = index_list
|
374 |
+
self.shuffle = shuffle
|
375 |
+
self.wrap = wrap
|
376 |
+
# if wrap, there will be not stop iteration called
|
377 |
+
# wrap True used during training, and wrap False used during test.
|
378 |
+
self._reset_iter()
|
379 |
+
|
380 |
+
def __iter__(self):
|
381 |
+
return self
|
382 |
+
|
383 |
+
def __next__(self):
|
384 |
+
wrapped = False
|
385 |
+
if self.iter_counter == len(self._index_list):
|
386 |
+
self._reset_iter()
|
387 |
+
if self.wrap:
|
388 |
+
wrapped = True
|
389 |
+
else:
|
390 |
+
raise StopIteration()
|
391 |
+
if len(self._index_list) == 0: # overflow when 0 samples
|
392 |
+
return None
|
393 |
+
elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
|
394 |
+
self.iter_counter += 1
|
395 |
+
return elem
|
396 |
+
|
397 |
+
def next(self):
|
398 |
+
return self.__next__()
|
399 |
+
|
400 |
+
def _reset_iter(self):
|
401 |
+
if self.shuffle:
|
402 |
+
rand_perm = npr.permutation(len(self.index_list))
|
403 |
+
self._index_list = [self.index_list[_] for _ in rand_perm]
|
404 |
+
else:
|
405 |
+
self._index_list = self.index_list
|
406 |
+
|
407 |
+
self.iter_counter = 0
|
408 |
+
|
409 |
+
def __len__(self):
|
410 |
+
return len(self.index_list)
|
411 |
+
|
412 |
+
def load_state_dict(self, state_dict=None):
|
413 |
+
if state_dict is None:
|
414 |
+
return
|
415 |
+
self._index_list = state_dict['index_list']
|
416 |
+
self.iter_counter = state_dict['iter_counter']
|
417 |
+
|
418 |
+
def state_dict(self, prefetched_num=None):
|
419 |
+
prefetched_num = prefetched_num or 0
|
420 |
+
return {
|
421 |
+
'index_list': self._index_list,
|
422 |
+
'iter_counter': self.iter_counter - prefetched_num
|
423 |
+
}
|
424 |
+
|
425 |
+
|
captioning/data/pth_loader.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.utils.data as data
|
16 |
+
|
17 |
+
import multiprocessing
|
18 |
+
import six
|
19 |
+
|
20 |
+
verbose = True
|
21 |
+
# import torch
|
22 |
+
# if torch.cuda.current_device() in [0, -1]:
|
23 |
+
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
24 |
+
verbose = False
|
25 |
+
|
26 |
+
class HybridLoader:
|
27 |
+
"""
|
28 |
+
If db_path is a director, then use normal file loading
|
29 |
+
If lmdb, then load from lmdb
|
30 |
+
The loading method depend on extention.
|
31 |
+
|
32 |
+
in_memory: if in_memory is True, we save all the features in memory
|
33 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
34 |
+
Should be useful for lmdb or h5.
|
35 |
+
(Copied this idea from vilbert)
|
36 |
+
"""
|
37 |
+
def __init__(self, db_path, ext, in_memory=False):
|
38 |
+
self.db_path = db_path
|
39 |
+
self.ext = ext
|
40 |
+
if self.ext == '.npy':
|
41 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
42 |
+
else:
|
43 |
+
self.loader = lambda x: np.load(six.BytesIO(x))['feat']
|
44 |
+
if db_path.endswith('.lmdb'):
|
45 |
+
self.db_type = 'lmdb'
|
46 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
47 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
48 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
49 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
50 |
+
self.db_type = 'pth'
|
51 |
+
self.feat_file = torch.load(db_path)
|
52 |
+
self.loader = lambda x: x
|
53 |
+
print('HybridLoader: ext is ignored')
|
54 |
+
elif db_path.endswith('h5'):
|
55 |
+
self.db_type = 'h5'
|
56 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
57 |
+
else:
|
58 |
+
self.db_type = 'dir'
|
59 |
+
|
60 |
+
self.in_memory = in_memory
|
61 |
+
if self.in_memory:
|
62 |
+
self.features = {}
|
63 |
+
|
64 |
+
def get(self, key):
|
65 |
+
|
66 |
+
if self.in_memory and key in self.features:
|
67 |
+
# We save f_input because we want to save the
|
68 |
+
# compressed bytes to save memory
|
69 |
+
f_input = self.features[key]
|
70 |
+
elif self.db_type == 'lmdb':
|
71 |
+
f_input = self.lmdb[key]
|
72 |
+
elif self.db_type == 'pth':
|
73 |
+
f_input = self.feat_file[key]
|
74 |
+
elif self.db_type == 'h5':
|
75 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
76 |
+
else:
|
77 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
78 |
+
|
79 |
+
if self.in_memory and key not in self.features:
|
80 |
+
self.features[key] = f_input
|
81 |
+
|
82 |
+
# load image
|
83 |
+
feat = self.loader(f_input)
|
84 |
+
|
85 |
+
return feat
|
86 |
+
|
87 |
+
class CaptionDataset(data.Dataset):
|
88 |
+
|
89 |
+
def get_vocab_size(self):
|
90 |
+
return self.vocab_size
|
91 |
+
|
92 |
+
def get_vocab(self):
|
93 |
+
return self.ix_to_word
|
94 |
+
|
95 |
+
def get_seq_length(self):
|
96 |
+
return self.seq_length
|
97 |
+
|
98 |
+
def __init__(self, opt):
|
99 |
+
self.opt = opt
|
100 |
+
self.seq_per_img = opt.seq_per_img
|
101 |
+
|
102 |
+
# feature related options
|
103 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
104 |
+
self.use_att = getattr(opt, 'use_att', True)
|
105 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
106 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
107 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
108 |
+
|
109 |
+
# load the json file which contains additional information about the dataset
|
110 |
+
if verbose:
|
111 |
+
print('DataLoader loading json file: ', opt.input_json)
|
112 |
+
self.info = json.load(open(self.opt.input_json))
|
113 |
+
if 'ix_to_word' in self.info:
|
114 |
+
self.ix_to_word = self.info['ix_to_word']
|
115 |
+
self.vocab_size = len(self.ix_to_word)
|
116 |
+
if verbose:
|
117 |
+
print('vocab size is ', self.vocab_size)
|
118 |
+
|
119 |
+
# open the hdf5 file
|
120 |
+
if verbose:
|
121 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
122 |
+
"""
|
123 |
+
Setting input_label_h5 to none is used when only doing generation.
|
124 |
+
For example, when you need to test on coco test set.
|
125 |
+
"""
|
126 |
+
if self.opt.input_label_h5 != 'none':
|
127 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
128 |
+
# load in the sequence data
|
129 |
+
seq_size = self.h5_label_file['labels'].shape
|
130 |
+
self.label = self.h5_label_file['labels'][:]
|
131 |
+
self.seq_length = seq_size[1]
|
132 |
+
if verbose:
|
133 |
+
print('max sequence length in data is', self.seq_length)
|
134 |
+
# load the pointers in full to RAM (should be small enough)
|
135 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
136 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
137 |
+
else:
|
138 |
+
self.seq_length = 1
|
139 |
+
|
140 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
141 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
142 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
143 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
144 |
+
|
145 |
+
self.use_clipscore = getattr(opt, 'use_clipscore', False)
|
146 |
+
# if self.use_clipscore:
|
147 |
+
self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
|
148 |
+
|
149 |
+
|
150 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
151 |
+
if verbose:
|
152 |
+
print('read %d image features' %(self.num_images))
|
153 |
+
|
154 |
+
# separate out indexes for each of the provided splits
|
155 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
156 |
+
for ix in range(len(self.info['images'])):
|
157 |
+
img = self.info['images'][ix]
|
158 |
+
if not 'split' in img:
|
159 |
+
self.split_ix['train'].append(ix)
|
160 |
+
self.split_ix['val'].append(ix)
|
161 |
+
self.split_ix['test'].append(ix)
|
162 |
+
elif img['split'] == 'train':
|
163 |
+
self.split_ix['train'].append(ix)
|
164 |
+
elif img['split'] == 'val':
|
165 |
+
self.split_ix['val'].append(ix)
|
166 |
+
elif img['split'] == 'test':
|
167 |
+
self.split_ix['test'].append(ix)
|
168 |
+
elif opt.train_only == 0: # restval
|
169 |
+
self.split_ix['train'].append(ix)
|
170 |
+
|
171 |
+
if verbose:
|
172 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
173 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
174 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
175 |
+
|
176 |
+
def get_captions(self, ix, seq_per_img):
|
177 |
+
# fetch the sequence labels
|
178 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
179 |
+
ix2 = self.label_end_ix[ix] - 1
|
180 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
181 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
182 |
+
|
183 |
+
if ncap < seq_per_img:
|
184 |
+
# we need to subsample (with replacement)
|
185 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
186 |
+
for q in range(seq_per_img):
|
187 |
+
ixl = random.randint(ix1,ix2)
|
188 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
189 |
+
else:
|
190 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
191 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
192 |
+
|
193 |
+
return seq
|
194 |
+
|
195 |
+
def collate_func(self, batch):
|
196 |
+
seq_per_img = self.seq_per_img
|
197 |
+
|
198 |
+
fc_batch = []
|
199 |
+
att_batch = []
|
200 |
+
label_batch = []
|
201 |
+
|
202 |
+
clip_vis_feat_batch = []
|
203 |
+
|
204 |
+
wrapped = False
|
205 |
+
|
206 |
+
infos = []
|
207 |
+
gts = []
|
208 |
+
|
209 |
+
for sample in batch:
|
210 |
+
# fetch image
|
211 |
+
# if self.use_clipscore:
|
212 |
+
tmp_fc, tmp_att, tmp_seq, \
|
213 |
+
ix, tmp_clip_vis_feat = sample
|
214 |
+
|
215 |
+
clip_vis_feat_batch.append(tmp_clip_vis_feat)
|
216 |
+
# else:
|
217 |
+
# tmp_fc, tmp_att, tmp_seq, \
|
218 |
+
# ix = sample
|
219 |
+
|
220 |
+
fc_batch.append(tmp_fc)
|
221 |
+
att_batch.append(tmp_att)
|
222 |
+
|
223 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
224 |
+
if hasattr(self, 'h5_label_file'):
|
225 |
+
# if there is ground truth
|
226 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
227 |
+
label_batch.append(tmp_label)
|
228 |
+
|
229 |
+
# Used for reward evaluation
|
230 |
+
if hasattr(self, 'h5_label_file'):
|
231 |
+
# if there is ground truth
|
232 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
233 |
+
else:
|
234 |
+
gts.append([])
|
235 |
+
|
236 |
+
# record associated info as well
|
237 |
+
info_dict = {}
|
238 |
+
info_dict['ix'] = ix
|
239 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
240 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
241 |
+
infos.append(info_dict)
|
242 |
+
|
243 |
+
# #sort by att_feat length
|
244 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
245 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
246 |
+
if self.use_clipscore:
|
247 |
+
fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
|
248 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
|
249 |
+
else:
|
250 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
251 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
252 |
+
data = {}
|
253 |
+
data['fc_feats'] = np.stack(fc_batch)
|
254 |
+
# merge att_feats
|
255 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
256 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
257 |
+
for i in range(len(att_batch)):
|
258 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
259 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
260 |
+
for i in range(len(att_batch)):
|
261 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
262 |
+
# set att_masks to None if attention features have same length
|
263 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
264 |
+
data['att_masks'] = None
|
265 |
+
|
266 |
+
# if self.use_clipscore:
|
267 |
+
data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
|
268 |
+
|
269 |
+
data['labels'] = np.vstack(label_batch)
|
270 |
+
# generate mask
|
271 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
272 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
273 |
+
for ix, row in enumerate(mask_batch):
|
274 |
+
row[:nonzeros[ix]] = 1
|
275 |
+
data['masks'] = mask_batch
|
276 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
277 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
278 |
+
|
279 |
+
data['gts'] = gts # all ground truth captions of each images
|
280 |
+
data['infos'] = infos
|
281 |
+
|
282 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
283 |
+
|
284 |
+
return data
|
285 |
+
|
286 |
+
def __getitem__(self, ix):
|
287 |
+
"""This function returns a tuple that is further passed to collate_fn
|
288 |
+
"""
|
289 |
+
if self.use_att:
|
290 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
291 |
+
# Reshape to K x C
|
292 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
293 |
+
if self.norm_att_feat:
|
294 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
295 |
+
if self.use_box:
|
296 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
297 |
+
# devided by image width and height
|
298 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
299 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
300 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
301 |
+
if self.norm_box_feat:
|
302 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
303 |
+
att_feat = np.hstack([att_feat, box_feat])
|
304 |
+
# sort the features by the size of boxes
|
305 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
306 |
+
else:
|
307 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
308 |
+
if self.use_fc:
|
309 |
+
try:
|
310 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
311 |
+
except:
|
312 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
313 |
+
fc_feat = att_feat.mean(0)
|
314 |
+
else:
|
315 |
+
fc_feat = np.zeros((0), dtype='float32')
|
316 |
+
if hasattr(self, 'h5_label_file'):
|
317 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
318 |
+
else:
|
319 |
+
seq = None
|
320 |
+
|
321 |
+
# if self.use_clipscore:
|
322 |
+
clip_vis_feat = self.clipscore_loader.get(
|
323 |
+
str(self.info['images'][ix]['id']))
|
324 |
+
|
325 |
+
return (fc_feat,
|
326 |
+
att_feat, seq,
|
327 |
+
ix, clip_vis_feat)
|
328 |
+
|
329 |
+
# return (fc_feat,
|
330 |
+
# att_feat, seq,
|
331 |
+
# ix)
|
332 |
+
|
333 |
+
def __len__(self):
|
334 |
+
return len(self.info['images'])
|
captioning/data/pth_loader_FineCapEval.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import json
|
6 |
+
import h5py
|
7 |
+
from lmdbdict import lmdbdict
|
8 |
+
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import numpy.random as npr
|
12 |
+
import random
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.utils.data as data
|
16 |
+
|
17 |
+
import multiprocessing
|
18 |
+
import six
|
19 |
+
|
20 |
+
verbose = True
|
21 |
+
# import torch
|
22 |
+
# if torch.cuda.current_device() in [0, -1]:
|
23 |
+
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
24 |
+
verbose = False
|
25 |
+
|
26 |
+
class HybridLoader:
|
27 |
+
"""
|
28 |
+
If db_path is a director, then use normal file loading
|
29 |
+
If lmdb, then load from lmdb
|
30 |
+
The loading method depend on extention.
|
31 |
+
|
32 |
+
in_memory: if in_memory is True, we save all the features in memory
|
33 |
+
For individual np(y|z)s, we don't need to do that because the system will do this for us.
|
34 |
+
Should be useful for lmdb or h5.
|
35 |
+
(Copied this idea from vilbert)
|
36 |
+
"""
|
37 |
+
def __init__(self, db_path, ext, in_memory=False):
|
38 |
+
self.db_path = db_path
|
39 |
+
self.ext = ext
|
40 |
+
if self.ext == '.npy':
|
41 |
+
self.loader = lambda x: np.load(six.BytesIO(x))
|
42 |
+
else:
|
43 |
+
self.loader = lambda x: np.load(six.BytesIO(x))['feat']
|
44 |
+
if db_path.endswith('.lmdb'):
|
45 |
+
self.db_type = 'lmdb'
|
46 |
+
self.lmdb = lmdbdict(db_path, unsafe=True)
|
47 |
+
self.lmdb._key_dumps = DUMPS_FUNC['ascii']
|
48 |
+
self.lmdb._value_loads = LOADS_FUNC['identity']
|
49 |
+
elif db_path.endswith('.pth'): # Assume a key,value dictionary
|
50 |
+
self.db_type = 'pth'
|
51 |
+
self.feat_file = torch.load(db_path)
|
52 |
+
self.loader = lambda x: x
|
53 |
+
print('HybridLoader: ext is ignored')
|
54 |
+
elif db_path.endswith('h5'):
|
55 |
+
self.db_type = 'h5'
|
56 |
+
self.loader = lambda x: np.array(x).astype('float32')
|
57 |
+
else:
|
58 |
+
self.db_type = 'dir'
|
59 |
+
|
60 |
+
self.in_memory = in_memory
|
61 |
+
if self.in_memory:
|
62 |
+
self.features = {}
|
63 |
+
|
64 |
+
def get(self, key):
|
65 |
+
|
66 |
+
if self.in_memory and key in self.features:
|
67 |
+
# We save f_input because we want to save the
|
68 |
+
# compressed bytes to save memory
|
69 |
+
f_input = self.features[key]
|
70 |
+
elif self.db_type == 'lmdb':
|
71 |
+
f_input = self.lmdb[key]
|
72 |
+
elif self.db_type == 'pth':
|
73 |
+
f_input = self.feat_file[key]
|
74 |
+
elif self.db_type == 'h5':
|
75 |
+
f_input = h5py.File(self.db_path, 'r')[key]
|
76 |
+
else:
|
77 |
+
f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
|
78 |
+
|
79 |
+
if self.in_memory and key not in self.features:
|
80 |
+
self.features[key] = f_input
|
81 |
+
|
82 |
+
# load image
|
83 |
+
feat = self.loader(f_input)
|
84 |
+
|
85 |
+
return feat
|
86 |
+
|
87 |
+
class CaptionDataset(data.Dataset):
|
88 |
+
|
89 |
+
def get_vocab_size(self):
|
90 |
+
return self.vocab_size
|
91 |
+
|
92 |
+
def get_vocab(self):
|
93 |
+
return self.ix_to_word
|
94 |
+
|
95 |
+
def get_seq_length(self):
|
96 |
+
return self.seq_length
|
97 |
+
|
98 |
+
def __init__(self, opt):
|
99 |
+
self.opt = opt
|
100 |
+
self.seq_per_img = opt.seq_per_img
|
101 |
+
|
102 |
+
# feature related options
|
103 |
+
self.use_fc = getattr(opt, 'use_fc', True)
|
104 |
+
self.use_att = getattr(opt, 'use_att', True)
|
105 |
+
self.use_box = getattr(opt, 'use_box', 0)
|
106 |
+
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
|
107 |
+
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
|
108 |
+
|
109 |
+
# load the json file which contains additional information about the dataset
|
110 |
+
if verbose:
|
111 |
+
print('DataLoader loading json file: ', opt.input_json)
|
112 |
+
self.info = json.load(open(self.opt.input_json))
|
113 |
+
if 'ix_to_word' in self.info:
|
114 |
+
self.ix_to_word = self.info['ix_to_word']
|
115 |
+
self.vocab_size = len(self.ix_to_word)
|
116 |
+
if verbose:
|
117 |
+
print('vocab size is ', self.vocab_size)
|
118 |
+
|
119 |
+
# open the hdf5 file
|
120 |
+
if verbose:
|
121 |
+
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
|
122 |
+
"""
|
123 |
+
Setting input_label_h5 to none is used when only doing generation.
|
124 |
+
For example, when you need to test on coco test set.
|
125 |
+
"""
|
126 |
+
if self.opt.input_label_h5 != 'none':
|
127 |
+
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
|
128 |
+
# load in the sequence data
|
129 |
+
seq_size = self.h5_label_file['labels'].shape
|
130 |
+
self.label = self.h5_label_file['labels'][:]
|
131 |
+
self.seq_length = seq_size[1]
|
132 |
+
if verbose:
|
133 |
+
print('max sequence length in data is', self.seq_length)
|
134 |
+
# load the pointers in full to RAM (should be small enough)
|
135 |
+
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
|
136 |
+
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
|
137 |
+
else:
|
138 |
+
self.seq_length = 1
|
139 |
+
|
140 |
+
self.data_in_memory = getattr(opt, 'data_in_memory', False)
|
141 |
+
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
|
142 |
+
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
|
143 |
+
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
|
144 |
+
|
145 |
+
self.use_clipscore = getattr(opt, 'use_clipscore', False)
|
146 |
+
if self.use_clipscore:
|
147 |
+
self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
|
148 |
+
|
149 |
+
|
150 |
+
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
|
151 |
+
if verbose:
|
152 |
+
print('read %d image features' %(self.num_images))
|
153 |
+
|
154 |
+
# separate out indexes for each of the provided splits
|
155 |
+
self.split_ix = {'train': [], 'val': [], 'test': []}
|
156 |
+
for ix in range(len(self.info['images'])):
|
157 |
+
img = self.info['images'][ix]
|
158 |
+
if not 'split' in img:
|
159 |
+
self.split_ix['train'].append(ix)
|
160 |
+
self.split_ix['val'].append(ix)
|
161 |
+
self.split_ix['test'].append(ix)
|
162 |
+
elif img['split'] == 'train':
|
163 |
+
self.split_ix['train'].append(ix)
|
164 |
+
elif img['split'] == 'val':
|
165 |
+
self.split_ix['val'].append(ix)
|
166 |
+
elif img['split'] == 'test':
|
167 |
+
self.split_ix['test'].append(ix)
|
168 |
+
elif opt.train_only == 0: # restval
|
169 |
+
self.split_ix['train'].append(ix)
|
170 |
+
|
171 |
+
if verbose:
|
172 |
+
print('assigned %d images to split train' %len(self.split_ix['train']))
|
173 |
+
print('assigned %d images to split val' %len(self.split_ix['val']))
|
174 |
+
print('assigned %d images to split test' %len(self.split_ix['test']))
|
175 |
+
|
176 |
+
def get_captions(self, ix, seq_per_img):
|
177 |
+
# fetch the sequence labels
|
178 |
+
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
|
179 |
+
ix2 = self.label_end_ix[ix] - 1
|
180 |
+
ncap = ix2 - ix1 + 1 # number of captions available for this image
|
181 |
+
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
|
182 |
+
|
183 |
+
if ncap < seq_per_img:
|
184 |
+
# we need to subsample (with replacement)
|
185 |
+
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
|
186 |
+
for q in range(seq_per_img):
|
187 |
+
ixl = random.randint(ix1,ix2)
|
188 |
+
seq[q, :] = self.label[ixl, :self.seq_length]
|
189 |
+
else:
|
190 |
+
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
|
191 |
+
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
|
192 |
+
|
193 |
+
return seq
|
194 |
+
|
195 |
+
def collate_func(self, batch):
|
196 |
+
seq_per_img = self.seq_per_img
|
197 |
+
|
198 |
+
fc_batch = []
|
199 |
+
att_batch = []
|
200 |
+
label_batch = []
|
201 |
+
|
202 |
+
clip_vis_feat_batch = []
|
203 |
+
|
204 |
+
wrapped = False
|
205 |
+
|
206 |
+
infos = []
|
207 |
+
gts = []
|
208 |
+
|
209 |
+
for sample in batch:
|
210 |
+
# fetch image
|
211 |
+
if self.use_clipscore:
|
212 |
+
tmp_fc, tmp_att, tmp_seq, \
|
213 |
+
ix, tmp_clip_vis_feat = sample
|
214 |
+
|
215 |
+
clip_vis_feat_batch.append(tmp_clip_vis_feat)
|
216 |
+
else:
|
217 |
+
tmp_fc, tmp_att, tmp_seq, \
|
218 |
+
ix = sample
|
219 |
+
|
220 |
+
fc_batch.append(tmp_fc)
|
221 |
+
att_batch.append(tmp_att)
|
222 |
+
|
223 |
+
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
|
224 |
+
if hasattr(self, 'h5_label_file'):
|
225 |
+
# if there is ground truth
|
226 |
+
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
|
227 |
+
label_batch.append(tmp_label)
|
228 |
+
|
229 |
+
# Used for reward evaluation
|
230 |
+
if hasattr(self, 'h5_label_file'):
|
231 |
+
# if there is ground truth
|
232 |
+
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
|
233 |
+
else:
|
234 |
+
gts.append([])
|
235 |
+
|
236 |
+
# record associated info as well
|
237 |
+
info_dict = {}
|
238 |
+
info_dict['ix'] = ix
|
239 |
+
info_dict['id'] = self.info['images'][ix]['id']
|
240 |
+
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
|
241 |
+
infos.append(info_dict)
|
242 |
+
|
243 |
+
# #sort by att_feat length
|
244 |
+
# fc_batch, att_batch, label_batch, gts, infos = \
|
245 |
+
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
|
246 |
+
if self.use_clipscore:
|
247 |
+
fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
|
248 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
|
249 |
+
else:
|
250 |
+
fc_batch, att_batch, label_batch, gts, infos = \
|
251 |
+
zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
|
252 |
+
data = {}
|
253 |
+
data['fc_feats'] = np.stack(fc_batch)
|
254 |
+
# merge att_feats
|
255 |
+
max_att_len = max([_.shape[0] for _ in att_batch])
|
256 |
+
data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
|
257 |
+
for i in range(len(att_batch)):
|
258 |
+
data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
|
259 |
+
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
|
260 |
+
for i in range(len(att_batch)):
|
261 |
+
data['att_masks'][i, :att_batch[i].shape[0]] = 1
|
262 |
+
# set att_masks to None if attention features have same length
|
263 |
+
if data['att_masks'].sum() == data['att_masks'].size:
|
264 |
+
data['att_masks'] = None
|
265 |
+
|
266 |
+
if self.use_clipscore:
|
267 |
+
data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
|
268 |
+
|
269 |
+
data['labels'] = np.vstack(label_batch)
|
270 |
+
# generate mask
|
271 |
+
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
|
272 |
+
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
|
273 |
+
for ix, row in enumerate(mask_batch):
|
274 |
+
row[:nonzeros[ix]] = 1
|
275 |
+
data['masks'] = mask_batch
|
276 |
+
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
277 |
+
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
278 |
+
|
279 |
+
data['gts'] = gts # all ground truth captions of each images
|
280 |
+
data['infos'] = infos
|
281 |
+
|
282 |
+
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
283 |
+
|
284 |
+
return data
|
285 |
+
|
286 |
+
def __getitem__(self, ix):
|
287 |
+
"""This function returns a tuple that is further passed to collate_fn
|
288 |
+
"""
|
289 |
+
if self.use_att:
|
290 |
+
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
291 |
+
# Reshape to K x C
|
292 |
+
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
293 |
+
if self.norm_att_feat:
|
294 |
+
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
295 |
+
if self.use_box:
|
296 |
+
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
297 |
+
# devided by image width and height
|
298 |
+
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
299 |
+
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
300 |
+
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
301 |
+
if self.norm_box_feat:
|
302 |
+
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
303 |
+
att_feat = np.hstack([att_feat, box_feat])
|
304 |
+
# sort the features by the size of boxes
|
305 |
+
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
306 |
+
else:
|
307 |
+
att_feat = np.zeros((0,0), dtype='float32')
|
308 |
+
if self.use_fc:
|
309 |
+
try:
|
310 |
+
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
311 |
+
except:
|
312 |
+
# Use average of attention when there is no fc provided (For bottomup feature)
|
313 |
+
fc_feat = att_feat.mean(0)
|
314 |
+
else:
|
315 |
+
fc_feat = np.zeros((0), dtype='float32')
|
316 |
+
if hasattr(self, 'h5_label_file'):
|
317 |
+
seq = self.get_captions(ix, self.seq_per_img)
|
318 |
+
else:
|
319 |
+
seq = None
|
320 |
+
|
321 |
+
if self.use_clipscore:
|
322 |
+
clip_vis_feat = self.clipscore_loader.get(
|
323 |
+
str(self.info['images'][ix]['id']))
|
324 |
+
|
325 |
+
return (fc_feat,
|
326 |
+
att_feat, seq,
|
327 |
+
ix, clip_vis_feat)
|
328 |
+
|
329 |
+
return (fc_feat,
|
330 |
+
att_feat, seq,
|
331 |
+
ix)
|
332 |
+
|
333 |
+
def __len__(self):
|
334 |
+
return len(self.info['images'])
|
captioning/models/AoAModel.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Implementation for paper 'Attention on Attention for Image Captioning'
|
2 |
+
# https://arxiv.org/abs/1908.06954
|
3 |
+
|
4 |
+
# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
|
5 |
+
|
6 |
+
from __future__ import absolute_import
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import print_function
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from .AttModel import pack_wrapper, AttModel, Attention
|
15 |
+
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
|
16 |
+
|
17 |
+
class MultiHeadedDotAttention(nn.Module):
|
18 |
+
def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
|
19 |
+
super(MultiHeadedDotAttention, self).__init__()
|
20 |
+
assert d_model * scale % h == 0
|
21 |
+
# We assume d_v always equals d_k
|
22 |
+
self.d_k = d_model * scale // h
|
23 |
+
self.h = h
|
24 |
+
|
25 |
+
# Do we need to do linear projections on K and V?
|
26 |
+
self.project_k_v = project_k_v
|
27 |
+
|
28 |
+
# normalize the query?
|
29 |
+
if norm_q:
|
30 |
+
self.norm = LayerNorm(d_model)
|
31 |
+
else:
|
32 |
+
self.norm = lambda x:x
|
33 |
+
self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
|
34 |
+
|
35 |
+
# output linear layer after the multi-head attention?
|
36 |
+
self.output_layer = nn.Linear(d_model * scale, d_model)
|
37 |
+
|
38 |
+
# apply aoa after attention?
|
39 |
+
self.use_aoa = do_aoa
|
40 |
+
if self.use_aoa:
|
41 |
+
self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
|
42 |
+
# dropout to the input of AoA layer
|
43 |
+
if dropout_aoa > 0:
|
44 |
+
self.dropout_aoa = nn.Dropout(p=dropout_aoa)
|
45 |
+
else:
|
46 |
+
self.dropout_aoa = lambda x:x
|
47 |
+
|
48 |
+
if self.use_aoa or not use_output_layer:
|
49 |
+
# AoA doesn't need the output linear layer
|
50 |
+
del self.output_layer
|
51 |
+
self.output_layer = lambda x:x
|
52 |
+
|
53 |
+
self.attn = None
|
54 |
+
self.dropout = nn.Dropout(p=dropout)
|
55 |
+
|
56 |
+
def forward(self, query, value, key, mask=None):
|
57 |
+
if mask is not None:
|
58 |
+
if len(mask.size()) == 2:
|
59 |
+
mask = mask.unsqueeze(-2)
|
60 |
+
# Same mask applied to all h heads.
|
61 |
+
mask = mask.unsqueeze(1)
|
62 |
+
|
63 |
+
single_query = 0
|
64 |
+
if len(query.size()) == 2:
|
65 |
+
single_query = 1
|
66 |
+
query = query.unsqueeze(1)
|
67 |
+
|
68 |
+
nbatches = query.size(0)
|
69 |
+
|
70 |
+
query = self.norm(query)
|
71 |
+
|
72 |
+
# Do all the linear projections in batch from d_model => h x d_k
|
73 |
+
if self.project_k_v == 0:
|
74 |
+
query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
75 |
+
key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
76 |
+
value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
77 |
+
else:
|
78 |
+
query_, key_, value_ = \
|
79 |
+
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
80 |
+
for l, x in zip(self.linears, (query, key, value))]
|
81 |
+
|
82 |
+
# Apply attention on all the projected vectors in batch.
|
83 |
+
x, self.attn = attention(query_, key_, value_, mask=mask,
|
84 |
+
dropout=self.dropout)
|
85 |
+
|
86 |
+
# "Concat" using a view
|
87 |
+
x = x.transpose(1, 2).contiguous() \
|
88 |
+
.view(nbatches, -1, self.h * self.d_k)
|
89 |
+
|
90 |
+
if self.use_aoa:
|
91 |
+
# Apply AoA
|
92 |
+
x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
|
93 |
+
x = self.output_layer(x)
|
94 |
+
|
95 |
+
if single_query:
|
96 |
+
query = query.squeeze(1)
|
97 |
+
x = x.squeeze(1)
|
98 |
+
return x
|
99 |
+
|
100 |
+
class AoA_Refiner_Layer(nn.Module):
|
101 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
102 |
+
super(AoA_Refiner_Layer, self).__init__()
|
103 |
+
self.self_attn = self_attn
|
104 |
+
self.feed_forward = feed_forward
|
105 |
+
self.use_ff = 0
|
106 |
+
if self.feed_forward is not None:
|
107 |
+
self.use_ff = 1
|
108 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
|
109 |
+
self.size = size
|
110 |
+
|
111 |
+
def forward(self, x, mask):
|
112 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
113 |
+
return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
|
114 |
+
|
115 |
+
class AoA_Refiner_Core(nn.Module):
|
116 |
+
def __init__(self, opt):
|
117 |
+
super(AoA_Refiner_Core, self).__init__()
|
118 |
+
attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
|
119 |
+
layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
|
120 |
+
self.layers = clones(layer, 6)
|
121 |
+
self.norm = LayerNorm(layer.size)
|
122 |
+
|
123 |
+
def forward(self, x, mask):
|
124 |
+
for layer in self.layers:
|
125 |
+
x = layer(x, mask)
|
126 |
+
return self.norm(x)
|
127 |
+
|
128 |
+
class AoA_Decoder_Core(nn.Module):
|
129 |
+
def __init__(self, opt):
|
130 |
+
super(AoA_Decoder_Core, self).__init__()
|
131 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
132 |
+
self.d_model = opt.rnn_size
|
133 |
+
self.use_multi_head = opt.use_multi_head
|
134 |
+
self.multi_head_scale = opt.multi_head_scale
|
135 |
+
self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
|
136 |
+
self.out_res = getattr(opt, 'out_res', 0)
|
137 |
+
self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
|
138 |
+
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
|
139 |
+
self.out_drop = nn.Dropout(self.drop_prob_lm)
|
140 |
+
|
141 |
+
if self.decoder_type == 'AoA':
|
142 |
+
# AoA layer
|
143 |
+
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
|
144 |
+
elif self.decoder_type == 'LSTM':
|
145 |
+
# LSTM layer
|
146 |
+
self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
|
147 |
+
else:
|
148 |
+
# Base linear layer
|
149 |
+
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
|
150 |
+
|
151 |
+
# if opt.use_multi_head == 1: # TODO, not implemented for now
|
152 |
+
# self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
|
153 |
+
if opt.use_multi_head == 2:
|
154 |
+
self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
|
155 |
+
else:
|
156 |
+
self.attention = Attention(opt)
|
157 |
+
|
158 |
+
if self.use_ctx_drop:
|
159 |
+
self.ctx_drop = nn.Dropout(self.drop_prob_lm)
|
160 |
+
else:
|
161 |
+
self.ctx_drop = lambda x :x
|
162 |
+
|
163 |
+
def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
|
164 |
+
# state[0][1] is the context vector at the last step
|
165 |
+
h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
|
166 |
+
|
167 |
+
if self.use_multi_head == 2:
|
168 |
+
att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
|
169 |
+
else:
|
170 |
+
att = self.attention(h_att, att_feats, p_att_feats, att_masks)
|
171 |
+
|
172 |
+
ctx_input = torch.cat([att, h_att], 1)
|
173 |
+
if self.decoder_type == 'LSTM':
|
174 |
+
output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
|
175 |
+
state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
|
176 |
+
else:
|
177 |
+
output = self.att2ctx(ctx_input)
|
178 |
+
# save the context vector to state[0][1]
|
179 |
+
state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
|
180 |
+
|
181 |
+
if self.out_res:
|
182 |
+
# add residual connection
|
183 |
+
output = output + h_att
|
184 |
+
|
185 |
+
output = self.out_drop(output)
|
186 |
+
return output, state
|
187 |
+
|
188 |
+
class AoAModel(AttModel):
|
189 |
+
def __init__(self, opt):
|
190 |
+
super(AoAModel, self).__init__(opt)
|
191 |
+
self.num_layers = 2
|
192 |
+
# mean pooling
|
193 |
+
self.use_mean_feats = getattr(opt, 'mean_feats', 1)
|
194 |
+
if opt.use_multi_head == 2:
|
195 |
+
del self.ctx2att
|
196 |
+
self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
|
197 |
+
|
198 |
+
if self.use_mean_feats:
|
199 |
+
del self.fc_embed
|
200 |
+
if opt.refine:
|
201 |
+
self.refiner = AoA_Refiner_Core(opt)
|
202 |
+
else:
|
203 |
+
self.refiner = lambda x,y : x
|
204 |
+
self.core = AoA_Decoder_Core(opt)
|
205 |
+
|
206 |
+
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
207 |
+
|
208 |
+
|
209 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
210 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
211 |
+
|
212 |
+
# embed att feats
|
213 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
214 |
+
att_feats = self.refiner(att_feats, att_masks)
|
215 |
+
|
216 |
+
if self.use_mean_feats:
|
217 |
+
# meaning pooling
|
218 |
+
if att_masks is None:
|
219 |
+
mean_feats = torch.mean(att_feats, dim=1)
|
220 |
+
else:
|
221 |
+
mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
|
222 |
+
else:
|
223 |
+
mean_feats = self.fc_embed(fc_feats)
|
224 |
+
|
225 |
+
# Project the attention feats first to reduce memory and computation.
|
226 |
+
p_att_feats = self.ctx2att(att_feats)
|
227 |
+
|
228 |
+
return mean_feats, att_feats, p_att_feats, att_masks
|
captioning/models/AttEnsemble.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is the implementation for ensemble evaluation.
|
2 |
+
|
3 |
+
from __future__ import absolute_import
|
4 |
+
from __future__ import division
|
5 |
+
from __future__ import print_function
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.autograd import *
|
12 |
+
|
13 |
+
from .CaptionModel import CaptionModel
|
14 |
+
from .AttModel import pack_wrapper, AttModel
|
15 |
+
|
16 |
+
class AttEnsemble(AttModel):
|
17 |
+
def __init__(self, models, weights=None):
|
18 |
+
CaptionModel.__init__(self)
|
19 |
+
# super(AttEnsemble, self).__init__()
|
20 |
+
|
21 |
+
self.models = nn.ModuleList(models)
|
22 |
+
self.vocab_size = models[0].vocab_size
|
23 |
+
self.seq_length = models[0].seq_length
|
24 |
+
self.bad_endings_ix = models[0].bad_endings_ix
|
25 |
+
self.ss_prob = 0
|
26 |
+
weights = weights or [1.0] * len(self.models)
|
27 |
+
self.register_buffer('weights', torch.tensor(weights))
|
28 |
+
|
29 |
+
def init_hidden(self, batch_size):
|
30 |
+
state = [m.init_hidden(batch_size) for m in self.models]
|
31 |
+
return self.pack_state(state)
|
32 |
+
|
33 |
+
def pack_state(self, state):
|
34 |
+
self.state_lengths = [len(_) for _ in state]
|
35 |
+
return sum([list(_) for _ in state], [])
|
36 |
+
|
37 |
+
def unpack_state(self, state):
|
38 |
+
out = []
|
39 |
+
for l in self.state_lengths:
|
40 |
+
out.append(state[:l])
|
41 |
+
state = state[l:]
|
42 |
+
return out
|
43 |
+
|
44 |
+
def embed(self, it):
|
45 |
+
return [m.embed(it) for m in self.models]
|
46 |
+
|
47 |
+
def core(self, *args):
|
48 |
+
return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
|
49 |
+
|
50 |
+
def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
|
51 |
+
# 'it' contains a word index
|
52 |
+
xt = self.embed(it)
|
53 |
+
|
54 |
+
state = self.unpack_state(state)
|
55 |
+
output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
|
56 |
+
logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
|
57 |
+
|
58 |
+
return logprobs, self.pack_state(state)
|
59 |
+
|
60 |
+
def _prepare_feature(self, *args):
|
61 |
+
return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
|
62 |
+
|
63 |
+
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
64 |
+
beam_size = opt.get('beam_size', 10)
|
65 |
+
batch_size = fc_feats.size(0)
|
66 |
+
|
67 |
+
fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
68 |
+
|
69 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
70 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
71 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
|
72 |
+
# lets process every image independently for now, for simplicity
|
73 |
+
|
74 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
75 |
+
for k in range(batch_size):
|
76 |
+
state = self.init_hidden(beam_size)
|
77 |
+
tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
|
78 |
+
tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
79 |
+
tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
80 |
+
tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
|
81 |
+
|
82 |
+
it = fc_feats[0].data.new(beam_size).long().zero_()
|
83 |
+
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
84 |
+
|
85 |
+
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
86 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
87 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
88 |
+
# return the samples and their log likelihoods
|
89 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
90 |
+
# return the samples and their log likelihoods
|
captioning/models/AttModel.py
ADDED
@@ -0,0 +1,969 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model
|
2 |
+
|
3 |
+
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
|
4 |
+
# https://arxiv.org/abs/1612.01887
|
5 |
+
# AdaAttMO is a modified version with maxout lstm
|
6 |
+
|
7 |
+
# Att2in is from Self-critical Sequence Training for Image Captioning
|
8 |
+
# https://arxiv.org/abs/1612.00563
|
9 |
+
# In this file we only have Att2in2, which is a slightly different version of att2in,
|
10 |
+
# in which the img feature embedding and word embedding is the same as what in adaatt.
|
11 |
+
|
12 |
+
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
|
13 |
+
# https://arxiv.org/abs/1707.07998
|
14 |
+
# However, it may not be identical to the author's architecture.
|
15 |
+
|
16 |
+
from __future__ import absolute_import
|
17 |
+
from __future__ import division
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from . import utils
|
25 |
+
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
26 |
+
|
27 |
+
from .CaptionModel import CaptionModel
|
28 |
+
|
29 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
30 |
+
bad_endings += ['the']
|
31 |
+
|
32 |
+
def sort_pack_padded_sequence(input, lengths):
|
33 |
+
sorted_lengths, indices = torch.sort(lengths, descending=True)
|
34 |
+
# tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
|
35 |
+
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
|
36 |
+
inv_ix = indices.clone()
|
37 |
+
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
|
38 |
+
return tmp, inv_ix
|
39 |
+
|
40 |
+
def pad_unsort_packed_sequence(input, inv_ix):
|
41 |
+
tmp, _ = pad_packed_sequence(input, batch_first=True)
|
42 |
+
tmp = tmp[inv_ix]
|
43 |
+
return tmp
|
44 |
+
|
45 |
+
def pack_wrapper(module, att_feats, att_masks):
|
46 |
+
if att_masks is not None:
|
47 |
+
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
|
48 |
+
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
|
49 |
+
else:
|
50 |
+
return module(att_feats)
|
51 |
+
|
52 |
+
class AttModel(CaptionModel):
|
53 |
+
def __init__(self, opt):
|
54 |
+
super(AttModel, self).__init__()
|
55 |
+
self.vocab_size = opt.vocab_size
|
56 |
+
self.input_encoding_size = opt.input_encoding_size
|
57 |
+
#self.rnn_type = opt.rnn_type
|
58 |
+
self.rnn_size = opt.rnn_size
|
59 |
+
self.num_layers = opt.num_layers
|
60 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
61 |
+
self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
|
62 |
+
self.fc_feat_size = opt.fc_feat_size
|
63 |
+
self.att_feat_size = opt.att_feat_size
|
64 |
+
self.att_hid_size = opt.att_hid_size
|
65 |
+
|
66 |
+
self.bos_idx = getattr(opt, 'bos_idx', 0)
|
67 |
+
self.eos_idx = getattr(opt, 'eos_idx', 0)
|
68 |
+
self.pad_idx = getattr(opt, 'pad_idx', 0)
|
69 |
+
|
70 |
+
self.use_bn = getattr(opt, 'use_bn', 0)
|
71 |
+
|
72 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
73 |
+
|
74 |
+
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
|
75 |
+
nn.ReLU(),
|
76 |
+
nn.Dropout(self.drop_prob_lm))
|
77 |
+
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
|
78 |
+
nn.ReLU(),
|
79 |
+
nn.Dropout(self.drop_prob_lm))
|
80 |
+
self.att_embed = nn.Sequential(*(
|
81 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
82 |
+
(nn.Linear(self.att_feat_size, self.rnn_size),
|
83 |
+
nn.ReLU(),
|
84 |
+
nn.Dropout(self.drop_prob_lm))+
|
85 |
+
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
|
86 |
+
|
87 |
+
self.logit_layers = getattr(opt, 'logit_layers', 1)
|
88 |
+
if self.logit_layers == 1:
|
89 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
90 |
+
else:
|
91 |
+
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
|
92 |
+
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
|
93 |
+
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
94 |
+
|
95 |
+
# For remove bad endding
|
96 |
+
self.vocab = opt.vocab
|
97 |
+
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
|
98 |
+
|
99 |
+
def init_hidden(self, bsz):
|
100 |
+
weight = self.logit.weight \
|
101 |
+
if hasattr(self.logit, "weight") \
|
102 |
+
else self.logit[0].weight
|
103 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
104 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
105 |
+
|
106 |
+
def clip_att(self, att_feats, att_masks):
|
107 |
+
# Clip the length of att_masks and att_feats to the maximum length
|
108 |
+
if att_masks is not None:
|
109 |
+
max_len = att_masks.data.long().sum(1).max()
|
110 |
+
att_feats = att_feats[:, :max_len].contiguous()
|
111 |
+
att_masks = att_masks[:, :max_len].contiguous()
|
112 |
+
return att_feats, att_masks
|
113 |
+
|
114 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
115 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
116 |
+
|
117 |
+
# embed fc and att feats
|
118 |
+
fc_feats = self.fc_embed(fc_feats)
|
119 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
120 |
+
|
121 |
+
# Project the attention feats first to reduce memory and computation comsumptions.
|
122 |
+
p_att_feats = self.ctx2att(att_feats)
|
123 |
+
|
124 |
+
return fc_feats, att_feats, p_att_feats, att_masks
|
125 |
+
|
126 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
127 |
+
batch_size = fc_feats.size(0)
|
128 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
129 |
+
seq = seq.reshape(-1, seq.shape[2])
|
130 |
+
seq_per_img = seq.shape[0] // batch_size
|
131 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
132 |
+
|
133 |
+
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
|
134 |
+
|
135 |
+
# Prepare the features
|
136 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
137 |
+
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost
|
138 |
+
|
139 |
+
if seq_per_img > 1:
|
140 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
|
141 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
142 |
+
)
|
143 |
+
|
144 |
+
for i in range(seq.size(1)):
|
145 |
+
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
|
146 |
+
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
|
147 |
+
sample_mask = sample_prob < self.ss_prob
|
148 |
+
if sample_mask.sum() == 0:
|
149 |
+
it = seq[:, i].clone()
|
150 |
+
else:
|
151 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
152 |
+
it = seq[:, i].data.clone()
|
153 |
+
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
|
154 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
155 |
+
else:
|
156 |
+
it = seq[:, i].clone()
|
157 |
+
# break if all the sequences end
|
158 |
+
if i >= 1 and seq[:, i].sum() == 0:
|
159 |
+
break
|
160 |
+
|
161 |
+
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
162 |
+
outputs[:, i] = output
|
163 |
+
|
164 |
+
return outputs
|
165 |
+
|
166 |
+
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
|
167 |
+
# 'it' contains a word index
|
168 |
+
xt = self.embed(it)
|
169 |
+
|
170 |
+
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
|
171 |
+
if output_logsoftmax:
|
172 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
173 |
+
else:
|
174 |
+
logprobs = self.logit(output)
|
175 |
+
|
176 |
+
return logprobs, state
|
177 |
+
|
178 |
+
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
179 |
+
beam_size = opt.get('beam_size', 10)
|
180 |
+
group_size = opt.get('group_size', 1)
|
181 |
+
sample_n = opt.get('sample_n', 10)
|
182 |
+
# when sample_n == beam_size then each beam is a sample.
|
183 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
184 |
+
batch_size = fc_feats.size(0)
|
185 |
+
|
186 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
187 |
+
|
188 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
189 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
190 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
191 |
+
# lets process every image independently for now, for simplicity
|
192 |
+
|
193 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
194 |
+
for k in range(batch_size):
|
195 |
+
state = self.init_hidden(beam_size)
|
196 |
+
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
|
197 |
+
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
|
198 |
+
)
|
199 |
+
|
200 |
+
for t in range(1):
|
201 |
+
if t == 0: # input <bos>
|
202 |
+
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
|
203 |
+
|
204 |
+
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
205 |
+
|
206 |
+
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
207 |
+
if sample_n == beam_size:
|
208 |
+
for _n in range(sample_n):
|
209 |
+
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
|
210 |
+
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
|
211 |
+
else:
|
212 |
+
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
213 |
+
seqLogprobs[k, :] = self.done_beams[k][0]['logps']
|
214 |
+
# return the samples and their log likelihoods
|
215 |
+
return seq, seqLogprobs
|
216 |
+
|
217 |
+
|
218 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
219 |
+
beam_size = opt.get('beam_size', 10)
|
220 |
+
group_size = opt.get('group_size', 1)
|
221 |
+
sample_n = opt.get('sample_n', 10)
|
222 |
+
# when sample_n == beam_size then each beam is a sample.
|
223 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
224 |
+
batch_size = fc_feats.size(0)
|
225 |
+
|
226 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
227 |
+
|
228 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
229 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
230 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
231 |
+
# lets process every image independently for now, for simplicity
|
232 |
+
|
233 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
234 |
+
|
235 |
+
state = self.init_hidden(batch_size)
|
236 |
+
|
237 |
+
# first step, feed bos
|
238 |
+
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
239 |
+
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
240 |
+
|
241 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
|
242 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
243 |
+
)
|
244 |
+
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
|
245 |
+
for k in range(batch_size):
|
246 |
+
if sample_n == beam_size:
|
247 |
+
for _n in range(sample_n):
|
248 |
+
seq_len = self.done_beams[k][_n]['seq'].shape[0]
|
249 |
+
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
|
250 |
+
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
|
251 |
+
else:
|
252 |
+
seq_len = self.done_beams[k][0]['seq'].shape[0]
|
253 |
+
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
254 |
+
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
|
255 |
+
# return the samples and their log likelihoods
|
256 |
+
return seq, seqLogprobs
|
257 |
+
|
258 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
259 |
+
|
260 |
+
sample_method = opt.get('sample_method', 'greedy')
|
261 |
+
beam_size = opt.get('beam_size', 1)
|
262 |
+
temperature = opt.get('temperature', 1.0)
|
263 |
+
sample_n = int(opt.get('sample_n', 1))
|
264 |
+
group_size = opt.get('group_size', 1)
|
265 |
+
output_logsoftmax = opt.get('output_logsoftmax', 1)
|
266 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
267 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
268 |
+
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
269 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
270 |
+
return self._sample_beam(fc_feats, att_feats, att_masks, opt)
|
271 |
+
if group_size > 1:
|
272 |
+
return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
|
273 |
+
|
274 |
+
batch_size = fc_feats.size(0)
|
275 |
+
state = self.init_hidden(batch_size*sample_n)
|
276 |
+
|
277 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
278 |
+
|
279 |
+
if sample_n > 1:
|
280 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
|
281 |
+
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
282 |
+
)
|
283 |
+
|
284 |
+
trigrams = [] # will be a list of batch_size dictionaries
|
285 |
+
|
286 |
+
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
287 |
+
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
288 |
+
for t in range(self.seq_length + 1):
|
289 |
+
if t == 0: # input <bos>
|
290 |
+
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
|
291 |
+
|
292 |
+
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
|
293 |
+
|
294 |
+
if decoding_constraint and t > 0:
|
295 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
296 |
+
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
297 |
+
logprobs = logprobs + tmp
|
298 |
+
|
299 |
+
if remove_bad_endings and t > 0:
|
300 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
301 |
+
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
302 |
+
# Make it impossible to generate bad_endings
|
303 |
+
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
304 |
+
logprobs = logprobs + tmp
|
305 |
+
|
306 |
+
# Mess with trigrams
|
307 |
+
# Copy from https://github.com/lukemelas/image-paragraph-captioning
|
308 |
+
if block_trigrams and t >= 3:
|
309 |
+
# Store trigram generated at last step
|
310 |
+
prev_two_batch = seq[:,t-3:t-1]
|
311 |
+
for i in range(batch_size): # = seq.size(0)
|
312 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
313 |
+
current = seq[i][t-1]
|
314 |
+
if t == 3: # initialize
|
315 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
316 |
+
elif t > 3:
|
317 |
+
if prev_two in trigrams[i]: # add to list
|
318 |
+
trigrams[i][prev_two].append(current)
|
319 |
+
else: # create list
|
320 |
+
trigrams[i][prev_two] = [current]
|
321 |
+
# Block used trigrams at next step
|
322 |
+
prev_two_batch = seq[:,t-2:t]
|
323 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
|
324 |
+
for i in range(batch_size):
|
325 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
326 |
+
if prev_two in trigrams[i]:
|
327 |
+
for j in trigrams[i][prev_two]:
|
328 |
+
mask[i,j] += 1
|
329 |
+
# Apply mask to log probs
|
330 |
+
#logprobs = logprobs - (mask * 1e9)
|
331 |
+
alpha = 2.0 # = 4
|
332 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
333 |
+
|
334 |
+
# sample the next word
|
335 |
+
if t == self.seq_length: # skip if we achieve maximum length
|
336 |
+
break
|
337 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
|
338 |
+
|
339 |
+
# stop when all finished
|
340 |
+
if t == 0:
|
341 |
+
unfinished = it != self.eos_idx
|
342 |
+
else:
|
343 |
+
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
|
344 |
+
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
|
345 |
+
unfinished = unfinished & (it != self.eos_idx)
|
346 |
+
seq[:,t] = it
|
347 |
+
seqLogprobs[:,t] = logprobs
|
348 |
+
# quit loop if all sequences have finished
|
349 |
+
if unfinished.sum() == 0:
|
350 |
+
break
|
351 |
+
|
352 |
+
return seq, seqLogprobs
|
353 |
+
|
354 |
+
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
355 |
+
|
356 |
+
sample_method = opt.get('sample_method', 'greedy')
|
357 |
+
beam_size = opt.get('beam_size', 1)
|
358 |
+
temperature = opt.get('temperature', 1.0)
|
359 |
+
group_size = opt.get('group_size', 1)
|
360 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
361 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
362 |
+
block_trigrams = opt.get('block_trigrams', 0)
|
363 |
+
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
364 |
+
|
365 |
+
batch_size = fc_feats.size(0)
|
366 |
+
state = self.init_hidden(batch_size)
|
367 |
+
|
368 |
+
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
369 |
+
|
370 |
+
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
|
371 |
+
|
372 |
+
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
|
373 |
+
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
|
374 |
+
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
|
375 |
+
|
376 |
+
for tt in range(self.seq_length + group_size):
|
377 |
+
for divm in range(group_size):
|
378 |
+
t = tt - divm
|
379 |
+
seq = seq_table[divm]
|
380 |
+
seqLogprobs = seqLogprobs_table[divm]
|
381 |
+
trigrams = trigrams_table[divm]
|
382 |
+
if t >= 0 and t <= self.seq_length-1:
|
383 |
+
if t == 0: # input <bos>
|
384 |
+
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
385 |
+
else:
|
386 |
+
it = seq[:, t-1] # changed
|
387 |
+
|
388 |
+
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
|
389 |
+
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
|
390 |
+
|
391 |
+
# Add diversity
|
392 |
+
if divm > 0:
|
393 |
+
unaug_logprobs = logprobs.clone()
|
394 |
+
for prev_choice in range(divm):
|
395 |
+
prev_decisions = seq_table[prev_choice][:, t]
|
396 |
+
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
|
397 |
+
|
398 |
+
if decoding_constraint and t > 0:
|
399 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
400 |
+
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
401 |
+
logprobs = logprobs + tmp
|
402 |
+
|
403 |
+
if remove_bad_endings and t > 0:
|
404 |
+
tmp = logprobs.new_zeros(logprobs.size())
|
405 |
+
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
406 |
+
# Impossible to generate remove_bad_endings
|
407 |
+
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
408 |
+
logprobs = logprobs + tmp
|
409 |
+
|
410 |
+
# Mess with trigrams
|
411 |
+
if block_trigrams and t >= 3:
|
412 |
+
# Store trigram generated at last step
|
413 |
+
prev_two_batch = seq[:,t-3:t-1]
|
414 |
+
for i in range(batch_size): # = seq.size(0)
|
415 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
416 |
+
current = seq[i][t-1]
|
417 |
+
if t == 3: # initialize
|
418 |
+
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
419 |
+
elif t > 3:
|
420 |
+
if prev_two in trigrams[i]: # add to list
|
421 |
+
trigrams[i][prev_two].append(current)
|
422 |
+
else: # create list
|
423 |
+
trigrams[i][prev_two] = [current]
|
424 |
+
# Block used trigrams at next step
|
425 |
+
prev_two_batch = seq[:,t-2:t]
|
426 |
+
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
|
427 |
+
for i in range(batch_size):
|
428 |
+
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
429 |
+
if prev_two in trigrams[i]:
|
430 |
+
for j in trigrams[i][prev_two]:
|
431 |
+
mask[i,j] += 1
|
432 |
+
# Apply mask to log probs
|
433 |
+
#logprobs = logprobs - (mask * 1e9)
|
434 |
+
alpha = 2.0 # = 4
|
435 |
+
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
436 |
+
|
437 |
+
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
|
438 |
+
|
439 |
+
# stop when all finished
|
440 |
+
if t == 0:
|
441 |
+
unfinished = it != self.eos_idx
|
442 |
+
else:
|
443 |
+
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
|
444 |
+
it[~unfinished] = self.pad_idx
|
445 |
+
unfinished = unfinished & (it != self.eos_idx) # changed
|
446 |
+
seq[:,t] = it
|
447 |
+
seqLogprobs[:,t] = sampleLogprobs.view(-1)
|
448 |
+
|
449 |
+
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
|
450 |
+
|
451 |
+
class AdaAtt_lstm(nn.Module):
|
452 |
+
def __init__(self, opt, use_maxout=True):
|
453 |
+
super(AdaAtt_lstm, self).__init__()
|
454 |
+
self.input_encoding_size = opt.input_encoding_size
|
455 |
+
#self.rnn_type = opt.rnn_type
|
456 |
+
self.rnn_size = opt.rnn_size
|
457 |
+
self.num_layers = opt.num_layers
|
458 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
459 |
+
self.fc_feat_size = opt.fc_feat_size
|
460 |
+
self.att_feat_size = opt.att_feat_size
|
461 |
+
self.att_hid_size = opt.att_hid_size
|
462 |
+
|
463 |
+
self.use_maxout = use_maxout
|
464 |
+
|
465 |
+
# Build a LSTM
|
466 |
+
self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
|
467 |
+
self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
|
468 |
+
|
469 |
+
self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
|
470 |
+
self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
|
471 |
+
|
472 |
+
# Layers for getting the fake region
|
473 |
+
if self.num_layers == 1:
|
474 |
+
self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
|
475 |
+
self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
|
476 |
+
else:
|
477 |
+
self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
|
478 |
+
self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
|
479 |
+
|
480 |
+
|
481 |
+
def forward(self, xt, img_fc, state):
|
482 |
+
|
483 |
+
hs = []
|
484 |
+
cs = []
|
485 |
+
for L in range(self.num_layers):
|
486 |
+
# c,h from previous timesteps
|
487 |
+
prev_h = state[0][L]
|
488 |
+
prev_c = state[1][L]
|
489 |
+
# the input to this layer
|
490 |
+
if L == 0:
|
491 |
+
x = xt
|
492 |
+
i2h = self.w2h(x) + self.v2h(img_fc)
|
493 |
+
else:
|
494 |
+
x = hs[-1]
|
495 |
+
x = F.dropout(x, self.drop_prob_lm, self.training)
|
496 |
+
i2h = self.i2h[L-1](x)
|
497 |
+
|
498 |
+
all_input_sums = i2h+self.h2h[L](prev_h)
|
499 |
+
|
500 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
501 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
502 |
+
# decode the gates
|
503 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
504 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
505 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
506 |
+
# decode the write inputs
|
507 |
+
if not self.use_maxout:
|
508 |
+
in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
|
509 |
+
else:
|
510 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
|
511 |
+
in_transform = torch.max(\
|
512 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
513 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
514 |
+
# perform the LSTM update
|
515 |
+
next_c = forget_gate * prev_c + in_gate * in_transform
|
516 |
+
# gated cells form the output
|
517 |
+
tanh_nex_c = torch.tanh(next_c)
|
518 |
+
next_h = out_gate * tanh_nex_c
|
519 |
+
if L == self.num_layers-1:
|
520 |
+
if L == 0:
|
521 |
+
i2h = self.r_w2h(x) + self.r_v2h(img_fc)
|
522 |
+
else:
|
523 |
+
i2h = self.r_i2h(x)
|
524 |
+
n5 = i2h+self.r_h2h(prev_h)
|
525 |
+
fake_region = torch.sigmoid(n5) * tanh_nex_c
|
526 |
+
|
527 |
+
cs.append(next_c)
|
528 |
+
hs.append(next_h)
|
529 |
+
|
530 |
+
# set up the decoder
|
531 |
+
top_h = hs[-1]
|
532 |
+
top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
|
533 |
+
fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
|
534 |
+
|
535 |
+
state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
|
536 |
+
torch.cat([_.unsqueeze(0) for _ in cs], 0))
|
537 |
+
return top_h, fake_region, state
|
538 |
+
|
539 |
+
class AdaAtt_attention(nn.Module):
|
540 |
+
def __init__(self, opt):
|
541 |
+
super(AdaAtt_attention, self).__init__()
|
542 |
+
self.input_encoding_size = opt.input_encoding_size
|
543 |
+
#self.rnn_type = opt.rnn_type
|
544 |
+
self.rnn_size = opt.rnn_size
|
545 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
546 |
+
self.att_hid_size = opt.att_hid_size
|
547 |
+
|
548 |
+
# fake region embed
|
549 |
+
self.fr_linear = nn.Sequential(
|
550 |
+
nn.Linear(self.rnn_size, self.input_encoding_size),
|
551 |
+
nn.ReLU(),
|
552 |
+
nn.Dropout(self.drop_prob_lm))
|
553 |
+
self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
554 |
+
|
555 |
+
# h out embed
|
556 |
+
self.ho_linear = nn.Sequential(
|
557 |
+
nn.Linear(self.rnn_size, self.input_encoding_size),
|
558 |
+
nn.Tanh(),
|
559 |
+
nn.Dropout(self.drop_prob_lm))
|
560 |
+
self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
561 |
+
|
562 |
+
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
563 |
+
self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
|
564 |
+
|
565 |
+
def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
|
566 |
+
|
567 |
+
# View into three dimensions
|
568 |
+
att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
|
569 |
+
conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
|
570 |
+
conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
|
571 |
+
|
572 |
+
# view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num
|
573 |
+
fake_region = self.fr_linear(fake_region)
|
574 |
+
fake_region_embed = self.fr_embed(fake_region)
|
575 |
+
|
576 |
+
h_out_linear = self.ho_linear(h_out)
|
577 |
+
h_out_embed = self.ho_embed(h_out_linear)
|
578 |
+
|
579 |
+
txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
|
580 |
+
|
581 |
+
img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
|
582 |
+
img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
|
583 |
+
|
584 |
+
hA = torch.tanh(img_all_embed + txt_replicate)
|
585 |
+
hA = F.dropout(hA,self.drop_prob_lm, self.training)
|
586 |
+
|
587 |
+
hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
|
588 |
+
PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
|
589 |
+
|
590 |
+
if att_masks is not None:
|
591 |
+
att_masks = att_masks.view(-1, att_size)
|
592 |
+
PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step.
|
593 |
+
PI = PI / PI.sum(1, keepdim=True)
|
594 |
+
|
595 |
+
visAtt = torch.bmm(PI.unsqueeze(1), img_all)
|
596 |
+
visAttdim = visAtt.squeeze(1)
|
597 |
+
|
598 |
+
atten_out = visAttdim + h_out_linear
|
599 |
+
|
600 |
+
h = torch.tanh(self.att2h(atten_out))
|
601 |
+
h = F.dropout(h, self.drop_prob_lm, self.training)
|
602 |
+
return h
|
603 |
+
|
604 |
+
class AdaAttCore(nn.Module):
|
605 |
+
def __init__(self, opt, use_maxout=False):
|
606 |
+
super(AdaAttCore, self).__init__()
|
607 |
+
self.lstm = AdaAtt_lstm(opt, use_maxout)
|
608 |
+
self.attention = AdaAtt_attention(opt)
|
609 |
+
|
610 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
611 |
+
h_out, p_out, state = self.lstm(xt, fc_feats, state)
|
612 |
+
atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
|
613 |
+
return atten_out, state
|
614 |
+
|
615 |
+
class UpDownCore(nn.Module):
|
616 |
+
def __init__(self, opt, use_maxout=False):
|
617 |
+
super(UpDownCore, self).__init__()
|
618 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
619 |
+
|
620 |
+
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
|
621 |
+
self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
|
622 |
+
self.attention = Attention(opt)
|
623 |
+
|
624 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
625 |
+
prev_h = state[0][-1]
|
626 |
+
att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
|
627 |
+
|
628 |
+
h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
|
629 |
+
|
630 |
+
att = self.attention(h_att, att_feats, p_att_feats, att_masks)
|
631 |
+
|
632 |
+
lang_lstm_input = torch.cat([att, h_att], 1)
|
633 |
+
# lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????
|
634 |
+
|
635 |
+
h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
|
636 |
+
|
637 |
+
output = F.dropout(h_lang, self.drop_prob_lm, self.training)
|
638 |
+
state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
|
639 |
+
|
640 |
+
return output, state
|
641 |
+
|
642 |
+
|
643 |
+
############################################################################
|
644 |
+
# Notice:
|
645 |
+
# StackAtt and DenseAtt are models that I randomly designed.
|
646 |
+
# They are not related to any paper.
|
647 |
+
############################################################################
|
648 |
+
|
649 |
+
from .FCModel import LSTMCore
|
650 |
+
class StackAttCore(nn.Module):
|
651 |
+
def __init__(self, opt, use_maxout=False):
|
652 |
+
super(StackAttCore, self).__init__()
|
653 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
654 |
+
|
655 |
+
# self.att0 = Attention(opt)
|
656 |
+
self.att1 = Attention(opt)
|
657 |
+
self.att2 = Attention(opt)
|
658 |
+
|
659 |
+
opt_input_encoding_size = opt.input_encoding_size
|
660 |
+
opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
|
661 |
+
self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
|
662 |
+
opt.input_encoding_size = opt.rnn_size * 2
|
663 |
+
self.lstm1 = LSTMCore(opt)
|
664 |
+
self.lstm2 = LSTMCore(opt)
|
665 |
+
opt.input_encoding_size = opt_input_encoding_size
|
666 |
+
|
667 |
+
# self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
668 |
+
self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
669 |
+
|
670 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
671 |
+
# att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
|
672 |
+
h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
|
673 |
+
att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
|
674 |
+
h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
|
675 |
+
att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
|
676 |
+
h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
|
677 |
+
|
678 |
+
return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
|
679 |
+
|
680 |
+
class DenseAttCore(nn.Module):
|
681 |
+
def __init__(self, opt, use_maxout=False):
|
682 |
+
super(DenseAttCore, self).__init__()
|
683 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
684 |
+
|
685 |
+
# self.att0 = Attention(opt)
|
686 |
+
self.att1 = Attention(opt)
|
687 |
+
self.att2 = Attention(opt)
|
688 |
+
|
689 |
+
opt_input_encoding_size = opt.input_encoding_size
|
690 |
+
opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
|
691 |
+
self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
|
692 |
+
opt.input_encoding_size = opt.rnn_size * 2
|
693 |
+
self.lstm1 = LSTMCore(opt)
|
694 |
+
self.lstm2 = LSTMCore(opt)
|
695 |
+
opt.input_encoding_size = opt_input_encoding_size
|
696 |
+
|
697 |
+
# self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
698 |
+
self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
|
699 |
+
|
700 |
+
# fuse h_0 and h_1
|
701 |
+
self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
|
702 |
+
nn.ReLU(),
|
703 |
+
nn.Dropout(opt.drop_prob_lm))
|
704 |
+
# fuse h_0, h_1 and h_2
|
705 |
+
self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
|
706 |
+
nn.ReLU(),
|
707 |
+
nn.Dropout(opt.drop_prob_lm))
|
708 |
+
|
709 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
710 |
+
# att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
|
711 |
+
h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
|
712 |
+
att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
|
713 |
+
h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
|
714 |
+
att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
|
715 |
+
h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
|
716 |
+
|
717 |
+
return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
|
718 |
+
|
719 |
+
class Attention(nn.Module):
|
720 |
+
def __init__(self, opt):
|
721 |
+
super(Attention, self).__init__()
|
722 |
+
self.rnn_size = opt.rnn_size
|
723 |
+
self.att_hid_size = opt.att_hid_size
|
724 |
+
|
725 |
+
self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
726 |
+
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
727 |
+
|
728 |
+
def forward(self, h, att_feats, p_att_feats, att_masks=None):
|
729 |
+
# The p_att_feats here is already projected
|
730 |
+
att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
|
731 |
+
att = p_att_feats.view(-1, att_size, self.att_hid_size)
|
732 |
+
|
733 |
+
att_h = self.h2att(h) # batch * att_hid_size
|
734 |
+
att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
|
735 |
+
dot = att + att_h # batch * att_size * att_hid_size
|
736 |
+
dot = torch.tanh(dot) # batch * att_size * att_hid_size
|
737 |
+
dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
|
738 |
+
dot = self.alpha_net(dot) # (batch * att_size) * 1
|
739 |
+
dot = dot.view(-1, att_size) # batch * att_size
|
740 |
+
|
741 |
+
weight = F.softmax(dot, dim=1) # batch * att_size
|
742 |
+
if att_masks is not None:
|
743 |
+
weight = weight * att_masks.view(-1, att_size).to(weight)
|
744 |
+
weight = weight / weight.sum(1, keepdim=True) # normalize to 1
|
745 |
+
att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
|
746 |
+
att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
|
747 |
+
|
748 |
+
return att_res
|
749 |
+
|
750 |
+
class Att2in2Core(nn.Module):
|
751 |
+
def __init__(self, opt):
|
752 |
+
super(Att2in2Core, self).__init__()
|
753 |
+
self.input_encoding_size = opt.input_encoding_size
|
754 |
+
#self.rnn_type = opt.rnn_type
|
755 |
+
self.rnn_size = opt.rnn_size
|
756 |
+
#self.num_layers = opt.num_layers
|
757 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
758 |
+
self.fc_feat_size = opt.fc_feat_size
|
759 |
+
self.att_feat_size = opt.att_feat_size
|
760 |
+
self.att_hid_size = opt.att_hid_size
|
761 |
+
|
762 |
+
# Build a LSTM
|
763 |
+
self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
|
764 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
765 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
766 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
767 |
+
|
768 |
+
self.attention = Attention(opt)
|
769 |
+
|
770 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
771 |
+
att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
|
772 |
+
|
773 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
|
774 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
775 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
776 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
777 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
778 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
779 |
+
|
780 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
|
781 |
+
self.a2c(att_res)
|
782 |
+
in_transform = torch.max(\
|
783 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
784 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
785 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
786 |
+
next_h = out_gate * torch.tanh(next_c)
|
787 |
+
|
788 |
+
output = self.dropout(next_h)
|
789 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
790 |
+
return output, state
|
791 |
+
|
792 |
+
class Att2inCore(Att2in2Core):
|
793 |
+
def __init__(self, opt):
|
794 |
+
super(Att2inCore, self).__init__(opt)
|
795 |
+
del self.a2c
|
796 |
+
self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
|
797 |
+
|
798 |
+
"""
|
799 |
+
Note this is my attempt to replicate att2all model in self-critical paper.
|
800 |
+
However, this is not a correct replication actually. Will fix it.
|
801 |
+
"""
|
802 |
+
class Att2all2Core(nn.Module):
|
803 |
+
def __init__(self, opt):
|
804 |
+
super(Att2all2Core, self).__init__()
|
805 |
+
self.input_encoding_size = opt.input_encoding_size
|
806 |
+
#self.rnn_type = opt.rnn_type
|
807 |
+
self.rnn_size = opt.rnn_size
|
808 |
+
#self.num_layers = opt.num_layers
|
809 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
810 |
+
self.fc_feat_size = opt.fc_feat_size
|
811 |
+
self.att_feat_size = opt.att_feat_size
|
812 |
+
self.att_hid_size = opt.att_hid_size
|
813 |
+
|
814 |
+
# Build a LSTM
|
815 |
+
self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
816 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
817 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
818 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
819 |
+
|
820 |
+
self.attention = Attention(opt)
|
821 |
+
|
822 |
+
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
823 |
+
att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
|
824 |
+
|
825 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
|
826 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
827 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
828 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
829 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
830 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
831 |
+
|
832 |
+
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
|
833 |
+
in_transform = torch.max(\
|
834 |
+
in_transform.narrow(1, 0, self.rnn_size),
|
835 |
+
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
836 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
837 |
+
next_h = out_gate * torch.tanh(next_c)
|
838 |
+
|
839 |
+
output = self.dropout(next_h)
|
840 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
841 |
+
return output, state
|
842 |
+
|
843 |
+
class AdaAttModel(AttModel):
|
844 |
+
def __init__(self, opt):
|
845 |
+
super(AdaAttModel, self).__init__(opt)
|
846 |
+
self.core = AdaAttCore(opt)
|
847 |
+
|
848 |
+
# AdaAtt with maxout lstm
|
849 |
+
class AdaAttMOModel(AttModel):
|
850 |
+
def __init__(self, opt):
|
851 |
+
super(AdaAttMOModel, self).__init__(opt)
|
852 |
+
self.core = AdaAttCore(opt, True)
|
853 |
+
|
854 |
+
class Att2in2Model(AttModel):
|
855 |
+
def __init__(self, opt):
|
856 |
+
super(Att2in2Model, self).__init__(opt)
|
857 |
+
self.core = Att2in2Core(opt)
|
858 |
+
delattr(self, 'fc_embed')
|
859 |
+
self.fc_embed = lambda x : x
|
860 |
+
|
861 |
+
class Att2all2Model(AttModel):
|
862 |
+
def __init__(self, opt):
|
863 |
+
super(Att2all2Model, self).__init__(opt)
|
864 |
+
self.core = Att2all2Core(opt)
|
865 |
+
delattr(self, 'fc_embed')
|
866 |
+
self.fc_embed = lambda x : x
|
867 |
+
|
868 |
+
class UpDownModel(AttModel):
|
869 |
+
def __init__(self, opt):
|
870 |
+
super(UpDownModel, self).__init__(opt)
|
871 |
+
self.num_layers = 2
|
872 |
+
self.core = UpDownCore(opt)
|
873 |
+
|
874 |
+
class StackAttModel(AttModel):
|
875 |
+
def __init__(self, opt):
|
876 |
+
super(StackAttModel, self).__init__(opt)
|
877 |
+
self.num_layers = 3
|
878 |
+
self.core = StackAttCore(opt)
|
879 |
+
|
880 |
+
class DenseAttModel(AttModel):
|
881 |
+
def __init__(self, opt):
|
882 |
+
super(DenseAttModel, self).__init__(opt)
|
883 |
+
self.num_layers = 3
|
884 |
+
self.core = DenseAttCore(opt)
|
885 |
+
|
886 |
+
class Att2inModel(AttModel):
|
887 |
+
def __init__(self, opt):
|
888 |
+
super(Att2inModel, self).__init__(opt)
|
889 |
+
del self.embed, self.fc_embed, self.att_embed
|
890 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
891 |
+
self.fc_embed = self.att_embed = lambda x: x
|
892 |
+
del self.ctx2att
|
893 |
+
self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
|
894 |
+
self.core = Att2inCore(opt)
|
895 |
+
self.init_weights()
|
896 |
+
|
897 |
+
def init_weights(self):
|
898 |
+
initrange = 0.1
|
899 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
900 |
+
self.logit.bias.data.fill_(0)
|
901 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
902 |
+
|
903 |
+
|
904 |
+
class NewFCModel(AttModel):
|
905 |
+
def __init__(self, opt):
|
906 |
+
super(NewFCModel, self).__init__(opt)
|
907 |
+
self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
908 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
909 |
+
self._core = LSTMCore(opt)
|
910 |
+
delattr(self, 'att_embed')
|
911 |
+
self.att_embed = lambda x : x
|
912 |
+
delattr(self, 'ctx2att')
|
913 |
+
self.ctx2att = lambda x: x
|
914 |
+
|
915 |
+
def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
|
916 |
+
# Step 0, feed the input image
|
917 |
+
# if (self.training and state[0].is_leaf) or \
|
918 |
+
# (not self.training and state[0].sum() == 0):
|
919 |
+
# _, state = self._core(fc_feats, state)
|
920 |
+
# three cases
|
921 |
+
# normal mle training
|
922 |
+
# Sample
|
923 |
+
# beam search (diverse beam search)
|
924 |
+
# fixed captioning module.
|
925 |
+
is_first_step = (state[0]==0).all(2).all(0) # size: B
|
926 |
+
if is_first_step.all():
|
927 |
+
_, state = self._core(fc_feats, state)
|
928 |
+
elif is_first_step.any():
|
929 |
+
# This is mostly for diverse beam search I think
|
930 |
+
new_state = [torch.zeros_like(_) for _ in state]
|
931 |
+
new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
|
932 |
+
new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
|
933 |
+
_, state = self._core(fc_feats, state)
|
934 |
+
new_state[0][:, is_first_step] = state[0][:, is_first_step]
|
935 |
+
new_state[1][:, is_first_step] = state[1][:, is_first_step]
|
936 |
+
state = new_state
|
937 |
+
# if (state[0]==0).all():
|
938 |
+
# # Let's forget about diverse beam search first
|
939 |
+
# _, state = self._core(fc_feats, state)
|
940 |
+
return self._core(xt, state)
|
941 |
+
|
942 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
943 |
+
fc_feats = self.fc_embed(fc_feats)
|
944 |
+
|
945 |
+
return fc_feats, att_feats, att_feats, att_masks
|
946 |
+
|
947 |
+
|
948 |
+
class LMModel(AttModel):
|
949 |
+
def __init__(self, opt):
|
950 |
+
super(LMModel, self).__init__(opt)
|
951 |
+
delattr(self, 'fc_embed')
|
952 |
+
self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
|
953 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
954 |
+
self._core = LSTMCore(opt)
|
955 |
+
delattr(self, 'att_embed')
|
956 |
+
self.att_embed = lambda x : x
|
957 |
+
delattr(self, 'ctx2att')
|
958 |
+
self.ctx2att = lambda x: x
|
959 |
+
|
960 |
+
def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
|
961 |
+
if (state[0]==0).all():
|
962 |
+
# Let's forget about diverse beam search first
|
963 |
+
_, state = self._core(fc_feats, state)
|
964 |
+
return self._core(xt, state)
|
965 |
+
|
966 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
967 |
+
fc_feats = self.fc_embed(fc_feats)
|
968 |
+
|
969 |
+
return fc_feats, None, None, None
|
captioning/models/BertCapModel.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BertCapModel is using huggingface transformer bert model as seq2seq model.
|
3 |
+
|
4 |
+
The result is not as goog as original transformer.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from __future__ import absolute_import
|
8 |
+
from __future__ import division
|
9 |
+
from __future__ import print_function
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import math
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
from .CaptionModel import CaptionModel
|
20 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
21 |
+
try:
|
22 |
+
from transformers import BertModel, BertConfig
|
23 |
+
except:
|
24 |
+
print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers')
|
25 |
+
from .TransformerModel import subsequent_mask, TransformerModel, Generator
|
26 |
+
|
27 |
+
class EncoderDecoder(nn.Module):
|
28 |
+
"""
|
29 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
30 |
+
other models.
|
31 |
+
"""
|
32 |
+
def __init__(self, encoder, decoder, generator):
|
33 |
+
super(EncoderDecoder, self).__init__()
|
34 |
+
self.encoder = encoder
|
35 |
+
self.decoder = decoder
|
36 |
+
self.generator = generator
|
37 |
+
|
38 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
39 |
+
"Take in and process masked src and target sequences."
|
40 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
41 |
+
tgt, tgt_mask)
|
42 |
+
|
43 |
+
def encode(self, src, src_mask):
|
44 |
+
return self.encoder(inputs_embeds=src,
|
45 |
+
attention_mask=src_mask)[0]
|
46 |
+
|
47 |
+
def decode(self, memory, src_mask, tgt, tgt_mask):
|
48 |
+
return self.decoder(input_ids=tgt,
|
49 |
+
attention_mask=tgt_mask,
|
50 |
+
encoder_hidden_states=memory,
|
51 |
+
encoder_attention_mask=src_mask)[0]
|
52 |
+
|
53 |
+
|
54 |
+
class BertCapModel(TransformerModel):
|
55 |
+
|
56 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
57 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
58 |
+
"Helper: Construct a model from hyperparameters."
|
59 |
+
enc_config = BertConfig(vocab_size=1,
|
60 |
+
hidden_size=d_model,
|
61 |
+
num_hidden_layers=N_enc,
|
62 |
+
num_attention_heads=h,
|
63 |
+
intermediate_size=d_ff,
|
64 |
+
hidden_dropout_prob=dropout,
|
65 |
+
attention_probs_dropout_prob=dropout,
|
66 |
+
max_position_embeddings=1,
|
67 |
+
type_vocab_size=1)
|
68 |
+
dec_config = BertConfig(vocab_size=tgt_vocab,
|
69 |
+
hidden_size=d_model,
|
70 |
+
num_hidden_layers=N_dec,
|
71 |
+
num_attention_heads=h,
|
72 |
+
intermediate_size=d_ff,
|
73 |
+
hidden_dropout_prob=dropout,
|
74 |
+
attention_probs_dropout_prob=dropout,
|
75 |
+
max_position_embeddings=17,
|
76 |
+
type_vocab_size=1,
|
77 |
+
is_decoder=True)
|
78 |
+
encoder = BertModel(enc_config)
|
79 |
+
def return_embeds(*args, **kwargs):
|
80 |
+
return kwargs['inputs_embeds']
|
81 |
+
del encoder.embeddings; encoder.embeddings = return_embeds
|
82 |
+
decoder = BertModel(dec_config)
|
83 |
+
model = EncoderDecoder(
|
84 |
+
encoder,
|
85 |
+
decoder,
|
86 |
+
Generator(d_model, tgt_vocab))
|
87 |
+
return model
|
88 |
+
|
89 |
+
def __init__(self, opt):
|
90 |
+
super(BertCapModel, self).__init__(opt)
|
91 |
+
|
92 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
93 |
+
"""
|
94 |
+
state = [ys.unsqueeze(0)]
|
95 |
+
"""
|
96 |
+
if len(state) == 0:
|
97 |
+
ys = it.unsqueeze(1)
|
98 |
+
else:
|
99 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
100 |
+
out = self.model.decode(memory, mask,
|
101 |
+
ys,
|
102 |
+
subsequent_mask(ys.size(1))
|
103 |
+
.to(memory.device))
|
104 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
captioning/models/CaptionModel.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains ShowAttendTell and AllImg model
|
2 |
+
|
3 |
+
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
|
4 |
+
# https://arxiv.org/abs/1502.03044
|
5 |
+
|
6 |
+
# AllImg is a model where
|
7 |
+
# img feature is concatenated with word embedding at every time step as the input of lstm
|
8 |
+
from __future__ import absolute_import
|
9 |
+
from __future__ import division
|
10 |
+
from __future__ import print_function
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.autograd import *
|
17 |
+
from ..utils import misc as utils
|
18 |
+
from . import utils as model_utils
|
19 |
+
|
20 |
+
|
21 |
+
class CaptionModel(nn.Module):
|
22 |
+
def __init__(self):
|
23 |
+
super(CaptionModel, self).__init__()
|
24 |
+
|
25 |
+
# implements beam search
|
26 |
+
# calls beam_step and returns the final set of beams
|
27 |
+
# augments log-probabilities with diversity terms when number of groups > 1
|
28 |
+
|
29 |
+
def forward(self, *args, **kwargs):
|
30 |
+
mode = kwargs.get('mode', 'forward')
|
31 |
+
if 'mode' in kwargs:
|
32 |
+
del kwargs['mode']
|
33 |
+
return getattr(self, '_'+mode)(*args, **kwargs)
|
34 |
+
|
35 |
+
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
|
36 |
+
|
37 |
+
# function computes the similarity score to be augmented
|
38 |
+
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
|
39 |
+
local_time = t - divm
|
40 |
+
unaug_logprobs = logprobs.clone()
|
41 |
+
batch_size = beam_seq_table[0].shape[0]
|
42 |
+
|
43 |
+
if divm > 0:
|
44 |
+
change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
|
45 |
+
for prev_choice in range(divm):
|
46 |
+
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
|
47 |
+
for prev_labels in range(bdash):
|
48 |
+
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
|
49 |
+
|
50 |
+
if local_time == 0:
|
51 |
+
logprobs = logprobs - change * diversity_lambda
|
52 |
+
else:
|
53 |
+
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
|
54 |
+
|
55 |
+
return logprobs, unaug_logprobs
|
56 |
+
|
57 |
+
|
58 |
+
# does one step of classical beam search
|
59 |
+
|
60 |
+
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
|
61 |
+
#INPUTS:
|
62 |
+
#logprobs: probabilities augmented after diversity N*bxV
|
63 |
+
#beam_size: obvious
|
64 |
+
#t : time instant
|
65 |
+
#beam_seq : tensor contanining the beams
|
66 |
+
#beam_seq_logprobs: tensor contanining the beam logprobs
|
67 |
+
#beam_logprobs_sum: tensor contanining joint logprobs
|
68 |
+
#OUPUTS:
|
69 |
+
#beam_seq : tensor containing the word indices of the decoded captions Nxbxl
|
70 |
+
#beam_seq_logprobs : log-probability of each decision made, NxbxlxV
|
71 |
+
#beam_logprobs_sum : joint log-probability of each beam Nxb
|
72 |
+
|
73 |
+
batch_size = beam_logprobs_sum.shape[0]
|
74 |
+
vocab_size = logprobs.shape[-1]
|
75 |
+
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
|
76 |
+
if t == 0:
|
77 |
+
assert logprobs.shape[1] == 1
|
78 |
+
beam_logprobs_sum = beam_logprobs_sum[:, :1]
|
79 |
+
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
|
80 |
+
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
|
81 |
+
ys, ix = ys[:,:beam_size], ix[:,:beam_size]
|
82 |
+
beam_ix = ix // vocab_size # Nxb which beam
|
83 |
+
selected_ix = ix % vocab_size # Nxb # which world
|
84 |
+
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
|
85 |
+
|
86 |
+
|
87 |
+
if t > 0:
|
88 |
+
# gather according to beam_ix
|
89 |
+
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
|
90 |
+
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
|
91 |
+
|
92 |
+
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
|
93 |
+
|
94 |
+
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
|
95 |
+
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
|
96 |
+
logprobs.reshape(batch_size, -1).gather(1, ix)
|
97 |
+
assert (beam_logprobs_sum == ys).all()
|
98 |
+
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
|
99 |
+
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
|
100 |
+
assert (_tmp_beam_logprobs == beam_logprobs).all()
|
101 |
+
beam_seq_logprobs = torch.cat([
|
102 |
+
beam_seq_logprobs,
|
103 |
+
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
|
104 |
+
|
105 |
+
new_state = [None for _ in state]
|
106 |
+
for _ix in range(len(new_state)):
|
107 |
+
# copy over state in previous beam q to new beam at vix
|
108 |
+
new_state[_ix] = state[_ix][:, state_ix]
|
109 |
+
state = new_state
|
110 |
+
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
|
111 |
+
|
112 |
+
# Start diverse_beam_search
|
113 |
+
opt = kwargs['opt']
|
114 |
+
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
|
115 |
+
beam_size = opt.get('beam_size', 10)
|
116 |
+
group_size = opt.get('group_size', 1)
|
117 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
118 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
119 |
+
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
120 |
+
suppress_UNK = opt.get('suppress_UNK', 0)
|
121 |
+
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
|
122 |
+
bdash = beam_size // group_size # beam per group
|
123 |
+
|
124 |
+
batch_size = init_logprobs.shape[0]
|
125 |
+
device = init_logprobs.device
|
126 |
+
# INITIALIZATIONS
|
127 |
+
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
|
128 |
+
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
|
129 |
+
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
|
130 |
+
|
131 |
+
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
|
132 |
+
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
|
133 |
+
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
|
134 |
+
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
|
135 |
+
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
|
136 |
+
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
|
137 |
+
logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
|
138 |
+
# END INIT
|
139 |
+
|
140 |
+
# Chunk elements in the args
|
141 |
+
args = list(args)
|
142 |
+
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
|
143 |
+
if self.__class__.__name__ == 'AttEnsemble':
|
144 |
+
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
|
145 |
+
else:
|
146 |
+
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
|
147 |
+
|
148 |
+
for t in range(self.seq_length + group_size - 1):
|
149 |
+
for divm in range(group_size):
|
150 |
+
if t >= divm and t <= self.seq_length + divm - 1:
|
151 |
+
# add diversity
|
152 |
+
logprobs = logprobs_table[divm]
|
153 |
+
# suppress previous word
|
154 |
+
if decoding_constraint and t-divm > 0:
|
155 |
+
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
|
156 |
+
if remove_bad_endings and t-divm > 0:
|
157 |
+
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
|
158 |
+
# suppress UNK tokens in the decoding
|
159 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
|
160 |
+
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
|
161 |
+
# diversity is added here
|
162 |
+
# the function directly modifies the logprobs values and hence, we need to return
|
163 |
+
# the unaugmented ones for sorting the candidates in the end. # for historical
|
164 |
+
# reasons :-)
|
165 |
+
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
|
166 |
+
|
167 |
+
# infer new beams
|
168 |
+
beam_seq_table[divm],\
|
169 |
+
beam_seq_logprobs_table[divm],\
|
170 |
+
beam_logprobs_sum_table[divm],\
|
171 |
+
state_table[divm] = beam_step(logprobs,
|
172 |
+
unaug_logprobs,
|
173 |
+
bdash,
|
174 |
+
t-divm,
|
175 |
+
beam_seq_table[divm],
|
176 |
+
beam_seq_logprobs_table[divm],
|
177 |
+
beam_logprobs_sum_table[divm],
|
178 |
+
state_table[divm])
|
179 |
+
|
180 |
+
# if time's up... or if end token is reached then copy beams
|
181 |
+
for b in range(batch_size):
|
182 |
+
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
|
183 |
+
assert beam_seq_table[divm].shape[-1] == t-divm+1
|
184 |
+
if t == self.seq_length + divm - 1:
|
185 |
+
is_end.fill_(1)
|
186 |
+
for vix in range(bdash):
|
187 |
+
if is_end[vix]:
|
188 |
+
final_beam = {
|
189 |
+
'seq': beam_seq_table[divm][b, vix].clone(),
|
190 |
+
'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
|
191 |
+
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
|
192 |
+
'p': beam_logprobs_sum_table[divm][b, vix].item()
|
193 |
+
}
|
194 |
+
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
|
195 |
+
done_beams_table[b][divm].append(final_beam)
|
196 |
+
beam_logprobs_sum_table[divm][b, is_end] -= 1000
|
197 |
+
|
198 |
+
# move the current group one step forward in time
|
199 |
+
|
200 |
+
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
|
201 |
+
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
|
202 |
+
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
|
203 |
+
|
204 |
+
# all beams are sorted by their log-probabilities
|
205 |
+
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
|
206 |
+
done_beams = [sum(_, []) for _ in done_beams_table]
|
207 |
+
return done_beams
|
208 |
+
|
209 |
+
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
|
210 |
+
|
211 |
+
# function computes the similarity score to be augmented
|
212 |
+
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
|
213 |
+
local_time = t - divm
|
214 |
+
unaug_logprobsf = logprobsf.clone()
|
215 |
+
for prev_choice in range(divm):
|
216 |
+
prev_decisions = beam_seq_table[prev_choice][local_time]
|
217 |
+
for sub_beam in range(bdash):
|
218 |
+
for prev_labels in range(bdash):
|
219 |
+
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
|
220 |
+
return unaug_logprobsf
|
221 |
+
|
222 |
+
# does one step of classical beam search
|
223 |
+
|
224 |
+
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
|
225 |
+
#INPUTS:
|
226 |
+
#logprobsf: probabilities augmented after diversity
|
227 |
+
#beam_size: obvious
|
228 |
+
#t : time instant
|
229 |
+
#beam_seq : tensor contanining the beams
|
230 |
+
#beam_seq_logprobs: tensor contanining the beam logprobs
|
231 |
+
#beam_logprobs_sum: tensor contanining joint logprobs
|
232 |
+
#OUPUTS:
|
233 |
+
#beam_seq : tensor containing the word indices of the decoded captions
|
234 |
+
#beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
|
235 |
+
#beam_logprobs_sum : joint log-probability of each beam
|
236 |
+
|
237 |
+
ys,ix = torch.sort(logprobsf,1,True)
|
238 |
+
candidates = []
|
239 |
+
cols = min(beam_size, ys.size(1))
|
240 |
+
rows = beam_size
|
241 |
+
if t == 0:
|
242 |
+
rows = 1
|
243 |
+
for c in range(cols): # for each column (word, essentially)
|
244 |
+
for q in range(rows): # for each beam expansion
|
245 |
+
#compute logprob of expanding beam q with word in (sorted) position c
|
246 |
+
local_logprob = ys[q,c].item()
|
247 |
+
candidate_logprob = beam_logprobs_sum[q] + local_logprob
|
248 |
+
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
|
249 |
+
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
|
250 |
+
candidates = sorted(candidates, key=lambda x: -x['p'])
|
251 |
+
|
252 |
+
new_state = [_.clone() for _ in state]
|
253 |
+
#beam_seq_prev, beam_seq_logprobs_prev
|
254 |
+
if t >= 1:
|
255 |
+
#we''ll need these as reference when we fork beams around
|
256 |
+
beam_seq_prev = beam_seq[:t].clone()
|
257 |
+
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
|
258 |
+
for vix in range(beam_size):
|
259 |
+
v = candidates[vix]
|
260 |
+
#fork beam index q into index vix
|
261 |
+
if t >= 1:
|
262 |
+
beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
|
263 |
+
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
|
264 |
+
#rearrange recurrent states
|
265 |
+
for state_ix in range(len(new_state)):
|
266 |
+
# copy over state in previous beam q to new beam at vix
|
267 |
+
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
|
268 |
+
#append new end terminal at the end of this beam
|
269 |
+
beam_seq[t, vix] = v['c'] # c'th word is the continuation
|
270 |
+
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
|
271 |
+
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
|
272 |
+
state = new_state
|
273 |
+
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
|
274 |
+
|
275 |
+
# Start diverse_beam_search
|
276 |
+
opt = kwargs['opt']
|
277 |
+
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
|
278 |
+
beam_size = opt.get('beam_size', 10)
|
279 |
+
group_size = opt.get('group_size', 1)
|
280 |
+
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
281 |
+
decoding_constraint = opt.get('decoding_constraint', 0)
|
282 |
+
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
283 |
+
suppress_UNK = opt.get('suppress_UNK', 0)
|
284 |
+
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
|
285 |
+
bdash = beam_size // group_size # beam per group
|
286 |
+
|
287 |
+
# INITIALIZATIONS
|
288 |
+
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
|
289 |
+
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
|
290 |
+
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
|
291 |
+
|
292 |
+
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
|
293 |
+
done_beams_table = [[] for _ in range(group_size)]
|
294 |
+
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
|
295 |
+
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
|
296 |
+
logprobs_table = list(init_logprobs.chunk(group_size, 0))
|
297 |
+
# END INIT
|
298 |
+
|
299 |
+
# Chunk elements in the args
|
300 |
+
args = list(args)
|
301 |
+
if self.__class__.__name__ == 'AttEnsemble':
|
302 |
+
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
|
303 |
+
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
|
304 |
+
else:
|
305 |
+
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
|
306 |
+
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
|
307 |
+
|
308 |
+
for t in range(self.seq_length + group_size - 1):
|
309 |
+
for divm in range(group_size):
|
310 |
+
if t >= divm and t <= self.seq_length + divm - 1:
|
311 |
+
# add diversity
|
312 |
+
logprobsf = logprobs_table[divm]
|
313 |
+
# suppress previous word
|
314 |
+
if decoding_constraint and t-divm > 0:
|
315 |
+
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
|
316 |
+
if remove_bad_endings and t-divm > 0:
|
317 |
+
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
|
318 |
+
# suppress UNK tokens in the decoding
|
319 |
+
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
|
320 |
+
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
|
321 |
+
# diversity is added here
|
322 |
+
# the function directly modifies the logprobsf values and hence, we need to return
|
323 |
+
# the unaugmented ones for sorting the candidates in the end. # for historical
|
324 |
+
# reasons :-)
|
325 |
+
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
|
326 |
+
|
327 |
+
# infer new beams
|
328 |
+
beam_seq_table[divm],\
|
329 |
+
beam_seq_logprobs_table[divm],\
|
330 |
+
beam_logprobs_sum_table[divm],\
|
331 |
+
state_table[divm],\
|
332 |
+
candidates_divm = beam_step(logprobsf,
|
333 |
+
unaug_logprobsf,
|
334 |
+
bdash,
|
335 |
+
t-divm,
|
336 |
+
beam_seq_table[divm],
|
337 |
+
beam_seq_logprobs_table[divm],
|
338 |
+
beam_logprobs_sum_table[divm],
|
339 |
+
state_table[divm])
|
340 |
+
|
341 |
+
# if time's up... or if end token is reached then copy beams
|
342 |
+
for vix in range(bdash):
|
343 |
+
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
|
344 |
+
final_beam = {
|
345 |
+
'seq': beam_seq_table[divm][:, vix].clone(),
|
346 |
+
'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
|
347 |
+
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
|
348 |
+
'p': beam_logprobs_sum_table[divm][vix].item()
|
349 |
+
}
|
350 |
+
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
|
351 |
+
done_beams_table[divm].append(final_beam)
|
352 |
+
# don't continue beams from finished sequences
|
353 |
+
beam_logprobs_sum_table[divm][vix] = -1000
|
354 |
+
|
355 |
+
# move the current group one step forward in time
|
356 |
+
|
357 |
+
it = beam_seq_table[divm][t-divm].to(logprobsf.device)
|
358 |
+
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
|
359 |
+
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
|
360 |
+
|
361 |
+
# all beams are sorted by their log-probabilities
|
362 |
+
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
|
363 |
+
done_beams = sum(done_beams_table, [])
|
364 |
+
return done_beams
|
365 |
+
|
366 |
+
def sample_next_word(self, logprobs, sample_method, temperature):
|
367 |
+
if sample_method == 'greedy':
|
368 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
369 |
+
it = it.view(-1).long()
|
370 |
+
elif sample_method == 'gumbel': # gumbel softmax
|
371 |
+
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
|
372 |
+
def sample_gumbel(shape, eps=1e-20):
|
373 |
+
U = torch.rand(shape).to(logprobs.device)
|
374 |
+
return -torch.log(-torch.log(U + eps) + eps)
|
375 |
+
def gumbel_softmax_sample(logits, temperature):
|
376 |
+
y = logits + sample_gumbel(logits.size())
|
377 |
+
return F.log_softmax(y / temperature, dim=-1)
|
378 |
+
_logprobs = gumbel_softmax_sample(logprobs, temperature)
|
379 |
+
_, it = torch.max(_logprobs.data, 1)
|
380 |
+
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
|
381 |
+
else:
|
382 |
+
logprobs = logprobs / temperature
|
383 |
+
if sample_method.startswith('top'): # topk sampling
|
384 |
+
top_num = float(sample_method[3:])
|
385 |
+
if 0 < top_num < 1:
|
386 |
+
# nucleus sampling from # The Curious Case of Neural Text Degeneration
|
387 |
+
probs = F.softmax(logprobs, dim=1)
|
388 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
|
389 |
+
_cumsum = sorted_probs.cumsum(1)
|
390 |
+
mask = _cumsum < top_num
|
391 |
+
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
|
392 |
+
sorted_probs = sorted_probs * mask.to(sorted_probs)
|
393 |
+
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
|
394 |
+
logprobs.scatter_(1, sorted_indices, sorted_probs.log())
|
395 |
+
else:
|
396 |
+
the_k = int(top_num)
|
397 |
+
tmp = torch.empty_like(logprobs).fill_(float('-inf'))
|
398 |
+
topk, indices = torch.topk(logprobs, the_k, dim=1)
|
399 |
+
tmp = tmp.scatter(1, indices, topk)
|
400 |
+
logprobs = tmp
|
401 |
+
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
|
402 |
+
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
|
403 |
+
return it, sampleLogprobs
|
404 |
+
|
405 |
+
|
406 |
+
def decode_sequence(self, seq):
|
407 |
+
return utils.decode_sequence(self.vocab, seq)
|
captioning/models/FCModel.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.autograd import *
|
9 |
+
from . import utils
|
10 |
+
|
11 |
+
from .CaptionModel import CaptionModel
|
12 |
+
|
13 |
+
class LSTMCore(nn.Module):
|
14 |
+
def __init__(self, opt):
|
15 |
+
super(LSTMCore, self).__init__()
|
16 |
+
self.input_encoding_size = opt.input_encoding_size
|
17 |
+
self.rnn_size = opt.rnn_size
|
18 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
19 |
+
|
20 |
+
# Build a LSTM
|
21 |
+
self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
|
22 |
+
self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
|
23 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
24 |
+
|
25 |
+
def forward(self, xt, state):
|
26 |
+
|
27 |
+
all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
|
28 |
+
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
29 |
+
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
30 |
+
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
31 |
+
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
32 |
+
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
33 |
+
|
34 |
+
in_transform = torch.max(\
|
35 |
+
all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
|
36 |
+
all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
|
37 |
+
next_c = forget_gate * state[1][-1] + in_gate * in_transform
|
38 |
+
next_h = out_gate * torch.tanh(next_c)
|
39 |
+
|
40 |
+
output = self.dropout(next_h)
|
41 |
+
state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
|
42 |
+
return output, state
|
43 |
+
|
44 |
+
class FCModel(CaptionModel):
|
45 |
+
def __init__(self, opt):
|
46 |
+
super(FCModel, self).__init__()
|
47 |
+
self.vocab_size = opt.vocab_size
|
48 |
+
self.input_encoding_size = opt.input_encoding_size
|
49 |
+
self.rnn_type = opt.rnn_type
|
50 |
+
self.rnn_size = opt.rnn_size
|
51 |
+
self.num_layers = opt.num_layers
|
52 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
53 |
+
self.seq_length = opt.seq_length
|
54 |
+
self.fc_feat_size = opt.fc_feat_size
|
55 |
+
|
56 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
57 |
+
|
58 |
+
self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
59 |
+
self.core = LSTMCore(opt)
|
60 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
61 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
62 |
+
|
63 |
+
self.init_weights()
|
64 |
+
|
65 |
+
def init_weights(self):
|
66 |
+
initrange = 0.1
|
67 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
68 |
+
self.logit.bias.data.fill_(0)
|
69 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
70 |
+
|
71 |
+
def init_hidden(self, bsz):
|
72 |
+
weight = self.logit.weight
|
73 |
+
if self.rnn_type == 'lstm':
|
74 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
75 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
76 |
+
else:
|
77 |
+
return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
|
78 |
+
|
79 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
80 |
+
batch_size = fc_feats.size(0)
|
81 |
+
seq_per_img = seq.shape[0] // batch_size
|
82 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
83 |
+
outputs = []
|
84 |
+
|
85 |
+
if seq_per_img > 1:
|
86 |
+
fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
|
87 |
+
|
88 |
+
for i in range(seq.size(1) + 1):
|
89 |
+
if i == 0:
|
90 |
+
xt = self.img_embed(fc_feats)
|
91 |
+
else:
|
92 |
+
if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
|
93 |
+
sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
|
94 |
+
sample_mask = sample_prob < self.ss_prob
|
95 |
+
if sample_mask.sum() == 0:
|
96 |
+
it = seq[:, i-1].clone()
|
97 |
+
else:
|
98 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
99 |
+
it = seq[:, i-1].data.clone()
|
100 |
+
#prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
|
101 |
+
#it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
|
102 |
+
prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
|
103 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
104 |
+
else:
|
105 |
+
it = seq[:, i-1].clone()
|
106 |
+
# break if all the sequences end
|
107 |
+
if i >= 2 and seq[:, i-1].sum() == 0:
|
108 |
+
break
|
109 |
+
xt = self.embed(it)
|
110 |
+
|
111 |
+
output, state = self.core(xt, state)
|
112 |
+
output = F.log_softmax(self.logit(output), dim=1)
|
113 |
+
outputs.append(output)
|
114 |
+
|
115 |
+
return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
|
116 |
+
|
117 |
+
def get_logprobs_state(self, it, state):
|
118 |
+
# 'it' is contains a word index
|
119 |
+
xt = self.embed(it)
|
120 |
+
|
121 |
+
output, state = self.core(xt, state)
|
122 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
123 |
+
|
124 |
+
return logprobs, state
|
125 |
+
|
126 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
127 |
+
beam_size = opt.get('beam_size', 10)
|
128 |
+
batch_size = fc_feats.size(0)
|
129 |
+
|
130 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
131 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
132 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
|
133 |
+
# lets process every image independently for now, for simplicity
|
134 |
+
|
135 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
136 |
+
for k in range(batch_size):
|
137 |
+
state = self.init_hidden(beam_size)
|
138 |
+
for t in range(2):
|
139 |
+
if t == 0:
|
140 |
+
xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
|
141 |
+
elif t == 1: # input <bos>
|
142 |
+
it = fc_feats.data.new(beam_size).long().zero_()
|
143 |
+
xt = self.embed(it)
|
144 |
+
|
145 |
+
output, state = self.core(xt, state)
|
146 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
147 |
+
|
148 |
+
self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
|
149 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
150 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
151 |
+
# return the samples and their log likelihoods
|
152 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
153 |
+
|
154 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
155 |
+
sample_method = opt.get('sample_method', 'greedy')
|
156 |
+
beam_size = opt.get('beam_size', 1)
|
157 |
+
temperature = opt.get('temperature', 1.0)
|
158 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
159 |
+
return self._sample_beam(fc_feats, att_feats, opt)
|
160 |
+
|
161 |
+
batch_size = fc_feats.size(0)
|
162 |
+
state = self.init_hidden(batch_size)
|
163 |
+
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
|
164 |
+
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
|
165 |
+
for t in range(self.seq_length + 2):
|
166 |
+
if t == 0:
|
167 |
+
xt = self.img_embed(fc_feats)
|
168 |
+
else:
|
169 |
+
if t == 1: # input <bos>
|
170 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
171 |
+
xt = self.embed(it)
|
172 |
+
|
173 |
+
output, state = self.core(xt, state)
|
174 |
+
logprobs = F.log_softmax(self.logit(output), dim=1)
|
175 |
+
|
176 |
+
# sample the next_word
|
177 |
+
if t == self.seq_length + 1: # skip if we achieve maximum length
|
178 |
+
break
|
179 |
+
if sample_method == 'greedy':
|
180 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
181 |
+
it = it.view(-1).long()
|
182 |
+
else:
|
183 |
+
if temperature == 1.0:
|
184 |
+
prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
|
185 |
+
else:
|
186 |
+
# scale logprobs by temperature
|
187 |
+
prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
|
188 |
+
it = torch.multinomial(prob_prev, 1).to(logprobs.device)
|
189 |
+
sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
|
190 |
+
it = it.view(-1).long() # and flatten indices for downstream processing
|
191 |
+
|
192 |
+
if t >= 1:
|
193 |
+
# stop when all finished
|
194 |
+
if t == 1:
|
195 |
+
unfinished = it > 0
|
196 |
+
else:
|
197 |
+
unfinished = unfinished & (it > 0)
|
198 |
+
it = it * unfinished.type_as(it)
|
199 |
+
seq[:,t-1] = it #seq[t] the input of t+2 time step
|
200 |
+
seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
|
201 |
+
if unfinished.sum() == 0:
|
202 |
+
break
|
203 |
+
|
204 |
+
return seq, seqLogprobs
|
captioning/models/M2Transformer.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
|
3 |
+
|
4 |
+
pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
|
5 |
+
|
6 |
+
Note:
|
7 |
+
Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import math
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from .CaptionModel import CaptionModel
|
23 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
24 |
+
|
25 |
+
try:
|
26 |
+
from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
|
27 |
+
except:
|
28 |
+
print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
|
29 |
+
from .TransformerModel import subsequent_mask, TransformerModel
|
30 |
+
|
31 |
+
|
32 |
+
class M2TransformerModel(TransformerModel):
|
33 |
+
|
34 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
35 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
36 |
+
"Helper: Construct a model from hyperparameters."
|
37 |
+
encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
|
38 |
+
attention_module_kwargs={'m': 40})
|
39 |
+
# Another implementation is to use MultiLevelEncoder + att_embed
|
40 |
+
decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
|
41 |
+
model = Transformer(0, encoder, decoder) # 0 is bos
|
42 |
+
return model
|
43 |
+
|
44 |
+
def __init__(self, opt):
|
45 |
+
super(M2TransformerModel, self).__init__(opt)
|
46 |
+
delattr(self, 'att_embed')
|
47 |
+
self.att_embed = lambda x: x # The visual embed is in the MAEncoder
|
48 |
+
# Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
|
49 |
+
# Also the attention mask seems wrong in MAEncoder too...intersting
|
50 |
+
|
51 |
+
def logit(self, x): # unsafe way
|
52 |
+
return x # M2transformer always output logsoftmax
|
53 |
+
|
54 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
55 |
+
|
56 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
57 |
+
memory, att_masks = self.model.encoder(att_feats)
|
58 |
+
|
59 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
60 |
+
|
61 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
62 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
63 |
+
seq = seq.reshape(-1, seq.shape[2])
|
64 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
65 |
+
|
66 |
+
seq = seq.clone()
|
67 |
+
seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
|
68 |
+
outputs = self.model(att_feats, seq)
|
69 |
+
|
70 |
+
return outputs
|
71 |
+
|
72 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
73 |
+
"""
|
74 |
+
state = [ys.unsqueeze(0)]
|
75 |
+
"""
|
76 |
+
if len(state) == 0:
|
77 |
+
ys = it.unsqueeze(1)
|
78 |
+
else:
|
79 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
80 |
+
out = self.model.decoder(ys, memory, mask)
|
81 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
82 |
+
|
83 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
84 |
+
beam_size = opt.get('beam_size', 10)
|
85 |
+
group_size = opt.get('group_size', 1)
|
86 |
+
sample_n = opt.get('sample_n', 10)
|
87 |
+
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
88 |
+
|
89 |
+
att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
|
90 |
+
seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
|
91 |
+
beam_size, return_probs=True, out_size=beam_size)
|
92 |
+
seq = seq.reshape(-1, *seq.shape[2:])
|
93 |
+
seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
|
94 |
+
|
95 |
+
# if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
|
96 |
+
# import pudb;pu.db
|
97 |
+
# seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
|
98 |
+
return seq, seqLogprobs
|
captioning/models/ShowTellModel.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.autograd import *
|
9 |
+
from . import utils
|
10 |
+
|
11 |
+
from .CaptionModel import CaptionModel
|
12 |
+
|
13 |
+
class ShowTellModel(CaptionModel):
|
14 |
+
def __init__(self, opt):
|
15 |
+
super(ShowTellModel, self).__init__()
|
16 |
+
self.vocab_size = opt.vocab_size
|
17 |
+
self.input_encoding_size = opt.input_encoding_size
|
18 |
+
self.rnn_type = opt.rnn_type
|
19 |
+
self.rnn_size = opt.rnn_size
|
20 |
+
self.num_layers = opt.num_layers
|
21 |
+
self.drop_prob_lm = opt.drop_prob_lm
|
22 |
+
self.seq_length = opt.seq_length
|
23 |
+
self.fc_feat_size = opt.fc_feat_size
|
24 |
+
|
25 |
+
self.ss_prob = 0.0 # Schedule sampling probability
|
26 |
+
|
27 |
+
self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
|
28 |
+
self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
|
29 |
+
self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
|
30 |
+
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
31 |
+
self.dropout = nn.Dropout(self.drop_prob_lm)
|
32 |
+
|
33 |
+
self.init_weights()
|
34 |
+
|
35 |
+
def init_weights(self):
|
36 |
+
initrange = 0.1
|
37 |
+
self.embed.weight.data.uniform_(-initrange, initrange)
|
38 |
+
self.logit.bias.data.fill_(0)
|
39 |
+
self.logit.weight.data.uniform_(-initrange, initrange)
|
40 |
+
|
41 |
+
def init_hidden(self, bsz):
|
42 |
+
weight = self.logit.weight
|
43 |
+
if self.rnn_type == 'lstm':
|
44 |
+
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
45 |
+
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
46 |
+
else:
|
47 |
+
return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
|
48 |
+
|
49 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
50 |
+
batch_size = fc_feats.size(0)
|
51 |
+
seq_per_img = seq.shape[0] // batch_size
|
52 |
+
state = self.init_hidden(batch_size*seq_per_img)
|
53 |
+
outputs = []
|
54 |
+
|
55 |
+
if seq_per_img > 1:
|
56 |
+
fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
|
57 |
+
|
58 |
+
for i in range(seq.size(1) + 1):
|
59 |
+
if i == 0:
|
60 |
+
xt = self.img_embed(fc_feats)
|
61 |
+
else:
|
62 |
+
if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
|
63 |
+
sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
|
64 |
+
sample_mask = sample_prob < self.ss_prob
|
65 |
+
if sample_mask.sum() == 0:
|
66 |
+
it = seq[:, i-1].clone()
|
67 |
+
else:
|
68 |
+
sample_ind = sample_mask.nonzero().view(-1)
|
69 |
+
it = seq[:, i-1].data.clone()
|
70 |
+
#prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
|
71 |
+
#it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
|
72 |
+
prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
|
73 |
+
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
74 |
+
else:
|
75 |
+
it = seq[:, i-1].clone()
|
76 |
+
# break if all the sequences end
|
77 |
+
if i >= 2 and seq[:, i-1].data.sum() == 0:
|
78 |
+
break
|
79 |
+
xt = self.embed(it)
|
80 |
+
|
81 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
82 |
+
output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
83 |
+
outputs.append(output)
|
84 |
+
|
85 |
+
return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
|
86 |
+
|
87 |
+
def get_logprobs_state(self, it, state):
|
88 |
+
# 'it' contains a word index
|
89 |
+
xt = self.embed(it)
|
90 |
+
|
91 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
92 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
93 |
+
|
94 |
+
return logprobs, state
|
95 |
+
|
96 |
+
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
97 |
+
beam_size = opt.get('beam_size', 10)
|
98 |
+
batch_size = fc_feats.size(0)
|
99 |
+
|
100 |
+
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
101 |
+
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
102 |
+
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
|
103 |
+
# lets process every image independently for now, for simplicity
|
104 |
+
|
105 |
+
self.done_beams = [[] for _ in range(batch_size)]
|
106 |
+
for k in range(batch_size):
|
107 |
+
state = self.init_hidden(beam_size)
|
108 |
+
for t in range(2):
|
109 |
+
if t == 0:
|
110 |
+
xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
|
111 |
+
elif t == 1: # input <bos>
|
112 |
+
it = fc_feats.data.new(beam_size).long().zero_()
|
113 |
+
xt = self.embed(it)
|
114 |
+
|
115 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
116 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
117 |
+
|
118 |
+
self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
|
119 |
+
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
120 |
+
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
121 |
+
# return the samples and their log likelihoods
|
122 |
+
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
123 |
+
|
124 |
+
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
125 |
+
sample_method = opt.get('sample_method', 'greedy')
|
126 |
+
beam_size = opt.get('beam_size', 1)
|
127 |
+
temperature = opt.get('temperature', 1.0)
|
128 |
+
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
129 |
+
return self.sample_beam(fc_feats, att_feats, opt)
|
130 |
+
|
131 |
+
batch_size = fc_feats.size(0)
|
132 |
+
state = self.init_hidden(batch_size)
|
133 |
+
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
|
134 |
+
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
|
135 |
+
for t in range(self.seq_length + 2):
|
136 |
+
if t == 0:
|
137 |
+
xt = self.img_embed(fc_feats)
|
138 |
+
else:
|
139 |
+
if t == 1: # input <bos>
|
140 |
+
it = fc_feats.data.new(batch_size).long().zero_()
|
141 |
+
xt = self.embed(it)
|
142 |
+
|
143 |
+
output, state = self.core(xt.unsqueeze(0), state)
|
144 |
+
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
|
145 |
+
|
146 |
+
# sample the next word
|
147 |
+
if t == self.seq_length + 1: # skip if we achieve maximum length
|
148 |
+
break
|
149 |
+
if sample_method == 'greedy':
|
150 |
+
sampleLogprobs, it = torch.max(logprobs.data, 1)
|
151 |
+
it = it.view(-1).long()
|
152 |
+
else:
|
153 |
+
if temperature == 1.0:
|
154 |
+
prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
|
155 |
+
else:
|
156 |
+
# scale logprobs by temperature
|
157 |
+
prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
|
158 |
+
it = torch.multinomial(prob_prev, 1).to(logprobs.device)
|
159 |
+
sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
|
160 |
+
it = it.view(-1).long() # and flatten indices for downstream processing
|
161 |
+
|
162 |
+
if t >= 1:
|
163 |
+
# stop when all finished
|
164 |
+
if t == 1:
|
165 |
+
unfinished = it > 0
|
166 |
+
else:
|
167 |
+
unfinished = unfinished & (it > 0)
|
168 |
+
it = it * unfinished.type_as(it)
|
169 |
+
seq[:,t-1] = it #seq[t] the input of t+2 time step
|
170 |
+
seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
|
171 |
+
if unfinished.sum() == 0:
|
172 |
+
break
|
173 |
+
|
174 |
+
return seq, seqLogprobs
|
captioning/models/TransformerModel.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Transformer network
|
2 |
+
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
|
3 |
+
|
4 |
+
# The cfg name correspondance:
|
5 |
+
# N=num_layers
|
6 |
+
# d_model=input_encoding_size
|
7 |
+
# d_ff=rnn_size
|
8 |
+
# h is always 8
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import math
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from .CaptionModel import CaptionModel
|
24 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
25 |
+
|
26 |
+
class EncoderDecoder(nn.Module):
|
27 |
+
"""
|
28 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
29 |
+
other models.
|
30 |
+
"""
|
31 |
+
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
|
32 |
+
super(EncoderDecoder, self).__init__()
|
33 |
+
self.encoder = encoder
|
34 |
+
self.decoder = decoder
|
35 |
+
self.src_embed = src_embed
|
36 |
+
self.tgt_embed = tgt_embed
|
37 |
+
self.generator = generator
|
38 |
+
|
39 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
40 |
+
"Take in and process masked src and target sequences."
|
41 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
42 |
+
tgt, tgt_mask)
|
43 |
+
|
44 |
+
def encode(self, src, src_mask):
|
45 |
+
return self.encoder(self.src_embed(src), src_mask)
|
46 |
+
|
47 |
+
def decode(self, memory, src_mask, tgt, tgt_mask):
|
48 |
+
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
|
49 |
+
|
50 |
+
class Generator(nn.Module):
|
51 |
+
"Define standard linear + softmax generation step."
|
52 |
+
def __init__(self, d_model, vocab):
|
53 |
+
super(Generator, self).__init__()
|
54 |
+
self.proj = nn.Linear(d_model, vocab)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return F.log_softmax(self.proj(x), dim=-1)
|
58 |
+
|
59 |
+
def clones(module, N):
|
60 |
+
"Produce N identical layers."
|
61 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
62 |
+
|
63 |
+
class Encoder(nn.Module):
|
64 |
+
"Core encoder is a stack of N layers"
|
65 |
+
def __init__(self, layer, N):
|
66 |
+
super(Encoder, self).__init__()
|
67 |
+
self.layers = clones(layer, N)
|
68 |
+
self.norm = LayerNorm(layer.size)
|
69 |
+
|
70 |
+
def forward(self, x, mask):
|
71 |
+
"Pass the input (and mask) through each layer in turn."
|
72 |
+
for layer in self.layers:
|
73 |
+
x = layer(x, mask)
|
74 |
+
return self.norm(x)
|
75 |
+
|
76 |
+
class LayerNorm(nn.Module):
|
77 |
+
"Construct a layernorm module (See citation for details)."
|
78 |
+
def __init__(self, features, eps=1e-6):
|
79 |
+
super(LayerNorm, self).__init__()
|
80 |
+
self.a_2 = nn.Parameter(torch.ones(features))
|
81 |
+
self.b_2 = nn.Parameter(torch.zeros(features))
|
82 |
+
self.eps = eps
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
mean = x.mean(-1, keepdim=True)
|
86 |
+
std = x.std(-1, keepdim=True)
|
87 |
+
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
88 |
+
|
89 |
+
class SublayerConnection(nn.Module):
|
90 |
+
"""
|
91 |
+
A residual connection followed by a layer norm.
|
92 |
+
Note for code simplicity the norm is first as opposed to last.
|
93 |
+
"""
|
94 |
+
def __init__(self, size, dropout):
|
95 |
+
super(SublayerConnection, self).__init__()
|
96 |
+
self.norm = LayerNorm(size)
|
97 |
+
self.dropout = nn.Dropout(dropout)
|
98 |
+
|
99 |
+
def forward(self, x, sublayer):
|
100 |
+
"Apply residual connection to any sublayer with the same size."
|
101 |
+
return x + self.dropout(sublayer(self.norm(x)))
|
102 |
+
|
103 |
+
class EncoderLayer(nn.Module):
|
104 |
+
"Encoder is made up of self-attn and feed forward (defined below)"
|
105 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
106 |
+
super(EncoderLayer, self).__init__()
|
107 |
+
self.self_attn = self_attn
|
108 |
+
self.feed_forward = feed_forward
|
109 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 2)
|
110 |
+
self.size = size
|
111 |
+
|
112 |
+
def forward(self, x, mask):
|
113 |
+
"Follow Figure 1 (left) for connections."
|
114 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
115 |
+
return self.sublayer[1](x, self.feed_forward)
|
116 |
+
|
117 |
+
class Decoder(nn.Module):
|
118 |
+
"Generic N layer decoder with masking."
|
119 |
+
def __init__(self, layer, N):
|
120 |
+
super(Decoder, self).__init__()
|
121 |
+
self.layers = clones(layer, N)
|
122 |
+
self.norm = LayerNorm(layer.size)
|
123 |
+
|
124 |
+
def forward(self, x, memory, src_mask, tgt_mask):
|
125 |
+
for layer in self.layers:
|
126 |
+
x = layer(x, memory, src_mask, tgt_mask)
|
127 |
+
return self.norm(x)
|
128 |
+
|
129 |
+
class DecoderLayer(nn.Module):
|
130 |
+
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
|
131 |
+
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
|
132 |
+
super(DecoderLayer, self).__init__()
|
133 |
+
self.size = size
|
134 |
+
self.self_attn = self_attn
|
135 |
+
self.src_attn = src_attn
|
136 |
+
self.feed_forward = feed_forward
|
137 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 3)
|
138 |
+
|
139 |
+
def forward(self, x, memory, src_mask, tgt_mask):
|
140 |
+
"Follow Figure 1 (right) for connections."
|
141 |
+
m = memory
|
142 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
|
143 |
+
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
|
144 |
+
return self.sublayer[2](x, self.feed_forward)
|
145 |
+
|
146 |
+
def subsequent_mask(size):
|
147 |
+
"Mask out subsequent positions."
|
148 |
+
attn_shape = (1, size, size)
|
149 |
+
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
|
150 |
+
return torch.from_numpy(subsequent_mask) == 0
|
151 |
+
|
152 |
+
def attention(query, key, value, mask=None, dropout=None):
|
153 |
+
"Compute 'Scaled Dot Product Attention'"
|
154 |
+
d_k = query.size(-1)
|
155 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) \
|
156 |
+
/ math.sqrt(d_k)
|
157 |
+
if mask is not None:
|
158 |
+
scores = scores.masked_fill(mask == 0, float('-inf'))
|
159 |
+
p_attn = F.softmax(scores, dim = -1)
|
160 |
+
if dropout is not None:
|
161 |
+
p_attn = dropout(p_attn)
|
162 |
+
return torch.matmul(p_attn, value), p_attn
|
163 |
+
|
164 |
+
class MultiHeadedAttention(nn.Module):
|
165 |
+
def __init__(self, h, d_model, dropout=0.1):
|
166 |
+
"Take in model size and number of heads."
|
167 |
+
super(MultiHeadedAttention, self).__init__()
|
168 |
+
assert d_model % h == 0
|
169 |
+
# We assume d_v always equals d_k
|
170 |
+
self.d_k = d_model // h
|
171 |
+
self.h = h
|
172 |
+
self.linears = clones(nn.Linear(d_model, d_model), 4)
|
173 |
+
self.attn = None
|
174 |
+
self.dropout = nn.Dropout(p=dropout)
|
175 |
+
|
176 |
+
def forward(self, query, key, value, mask=None):
|
177 |
+
"Implements Figure 2"
|
178 |
+
if mask is not None:
|
179 |
+
# Same mask applied to all h heads.
|
180 |
+
mask = mask.unsqueeze(1)
|
181 |
+
nbatches = query.size(0)
|
182 |
+
|
183 |
+
# 1) Do all the linear projections in batch from d_model => h x d_k
|
184 |
+
query, key, value = \
|
185 |
+
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
186 |
+
for l, x in zip(self.linears, (query, key, value))]
|
187 |
+
|
188 |
+
# 2) Apply attention on all the projected vectors in batch.
|
189 |
+
x, self.attn = attention(query, key, value, mask=mask,
|
190 |
+
dropout=self.dropout)
|
191 |
+
|
192 |
+
# 3) "Concat" using a view and apply a final linear.
|
193 |
+
x = x.transpose(1, 2).contiguous() \
|
194 |
+
.view(nbatches, -1, self.h * self.d_k)
|
195 |
+
return self.linears[-1](x)
|
196 |
+
|
197 |
+
class PositionwiseFeedForward(nn.Module):
|
198 |
+
"Implements FFN equation."
|
199 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
200 |
+
super(PositionwiseFeedForward, self).__init__()
|
201 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
202 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
203 |
+
self.dropout = nn.Dropout(dropout)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
207 |
+
|
208 |
+
class Embeddings(nn.Module):
|
209 |
+
def __init__(self, d_model, vocab):
|
210 |
+
super(Embeddings, self).__init__()
|
211 |
+
self.lut = nn.Embedding(vocab, d_model)
|
212 |
+
self.d_model = d_model
|
213 |
+
|
214 |
+
def forward(self, x):
|
215 |
+
return self.lut(x) * math.sqrt(self.d_model)
|
216 |
+
|
217 |
+
class PositionalEncoding(nn.Module):
|
218 |
+
"Implement the PE function."
|
219 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
220 |
+
super(PositionalEncoding, self).__init__()
|
221 |
+
self.dropout = nn.Dropout(p=dropout)
|
222 |
+
|
223 |
+
# Compute the positional encodings once in log space.
|
224 |
+
pe = torch.zeros(max_len, d_model)
|
225 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
226 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
227 |
+
-(math.log(10000.0) / d_model))
|
228 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
229 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
230 |
+
pe = pe.unsqueeze(0)
|
231 |
+
self.register_buffer('pe', pe)
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
x = x + self.pe[:, :x.size(1)]
|
235 |
+
return self.dropout(x)
|
236 |
+
|
237 |
+
class TransformerModel(AttModel):
|
238 |
+
|
239 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
240 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
241 |
+
"Helper: Construct a model from hyperparameters."
|
242 |
+
c = copy.deepcopy
|
243 |
+
attn = MultiHeadedAttention(h, d_model, dropout)
|
244 |
+
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
|
245 |
+
position = PositionalEncoding(d_model, dropout)
|
246 |
+
model = EncoderDecoder(
|
247 |
+
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
|
248 |
+
Decoder(DecoderLayer(d_model, c(attn), c(attn),
|
249 |
+
c(ff), dropout), N_dec),
|
250 |
+
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
|
251 |
+
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
|
252 |
+
Generator(d_model, tgt_vocab))
|
253 |
+
|
254 |
+
# This was important from their code.
|
255 |
+
# Initialize parameters with Glorot / fan_avg.
|
256 |
+
for p in model.parameters():
|
257 |
+
if p.dim() > 1:
|
258 |
+
nn.init.xavier_uniform_(p)
|
259 |
+
return model
|
260 |
+
|
261 |
+
def __init__(self, opt):
|
262 |
+
super(TransformerModel, self).__init__(opt)
|
263 |
+
self.opt = opt
|
264 |
+
# self.config = yaml.load(open(opt.config_file))
|
265 |
+
|
266 |
+
self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
|
267 |
+
self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
|
268 |
+
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
269 |
+
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
|
270 |
+
self.h = getattr(opt, 'num_att_heads', 8)
|
271 |
+
self.dropout = getattr(opt, 'dropout', 0.1)
|
272 |
+
|
273 |
+
delattr(self, 'att_embed')
|
274 |
+
self.att_embed = nn.Sequential(*(
|
275 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
276 |
+
(nn.Linear(self.att_feat_size, self.d_model),
|
277 |
+
nn.ReLU(),
|
278 |
+
nn.Dropout(self.drop_prob_lm))+
|
279 |
+
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
|
280 |
+
|
281 |
+
delattr(self, 'embed')
|
282 |
+
self.embed = lambda x : x
|
283 |
+
delattr(self, 'fc_embed')
|
284 |
+
self.fc_embed = lambda x : x
|
285 |
+
delattr(self, 'logit')
|
286 |
+
del self.ctx2att
|
287 |
+
|
288 |
+
tgt_vocab = self.vocab_size + 1
|
289 |
+
|
290 |
+
|
291 |
+
self.model = self.make_model(0, tgt_vocab,
|
292 |
+
N_enc=self.N_enc,
|
293 |
+
N_dec=self.N_dec,
|
294 |
+
d_model=self.d_model,
|
295 |
+
d_ff=self.d_ff,
|
296 |
+
h=self.h,
|
297 |
+
dropout=self.dropout)
|
298 |
+
|
299 |
+
def logit(self, x): # unsafe way
|
300 |
+
return self.model.generator.proj(x)
|
301 |
+
|
302 |
+
def init_hidden(self, bsz):
|
303 |
+
return []
|
304 |
+
|
305 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
306 |
+
|
307 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
308 |
+
memory = self.model.encode(att_feats, att_masks)
|
309 |
+
|
310 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
311 |
+
|
312 |
+
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
|
313 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
314 |
+
|
315 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
316 |
+
|
317 |
+
if att_masks is None:
|
318 |
+
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
|
319 |
+
att_masks = att_masks.unsqueeze(-2)
|
320 |
+
|
321 |
+
if seq is not None:
|
322 |
+
# crop the last one
|
323 |
+
# seq = seq[:,:-1]
|
324 |
+
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
|
325 |
+
seq_mask[:,0] = 1 # bos
|
326 |
+
|
327 |
+
seq_mask = seq_mask.unsqueeze(-2)
|
328 |
+
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
|
329 |
+
|
330 |
+
seq_per_img = seq.shape[0] // att_feats.shape[0]
|
331 |
+
if seq_per_img > 1:
|
332 |
+
att_feats, att_masks = utils.repeat_tensors(seq_per_img,
|
333 |
+
[att_feats, att_masks]
|
334 |
+
)
|
335 |
+
else:
|
336 |
+
seq_mask = None
|
337 |
+
|
338 |
+
return att_feats, seq, att_masks, seq_mask
|
339 |
+
|
340 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
341 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
342 |
+
seq = seq.reshape(-1, seq.shape[2])
|
343 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
344 |
+
|
345 |
+
out = self.model(att_feats, seq, att_masks, seq_mask)
|
346 |
+
|
347 |
+
outputs = self.model.generator(out)
|
348 |
+
return outputs
|
349 |
+
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
350 |
+
|
351 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
352 |
+
"""
|
353 |
+
state = [ys.unsqueeze(0)]
|
354 |
+
"""
|
355 |
+
if len(state) == 0:
|
356 |
+
ys = it.unsqueeze(1)
|
357 |
+
else:
|
358 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
359 |
+
out = self.model.decode(memory, mask,
|
360 |
+
ys,
|
361 |
+
subsequent_mask(ys.size(1))
|
362 |
+
.to(memory.device))
|
363 |
+
return out[:, -1], [ys.unsqueeze(0)]
|
captioning/models/__init__.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import os
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .ShowTellModel import ShowTellModel
|
12 |
+
from .FCModel import FCModel
|
13 |
+
from .AttModel import *
|
14 |
+
from .TransformerModel import TransformerModel
|
15 |
+
from .cachedTransformer import TransformerModel as cachedTransformer
|
16 |
+
from .BertCapModel import BertCapModel
|
17 |
+
from .M2Transformer import M2TransformerModel
|
18 |
+
from .AoAModel import AoAModel
|
19 |
+
|
20 |
+
def setup(opt):
|
21 |
+
if opt.caption_model in ['fc', 'show_tell']:
|
22 |
+
print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
|
23 |
+
if opt.caption_model == 'fc':
|
24 |
+
print('Use newfc instead of fc')
|
25 |
+
if opt.caption_model == 'fc':
|
26 |
+
model = FCModel(opt)
|
27 |
+
elif opt.caption_model == 'language_model':
|
28 |
+
model = LMModel(opt)
|
29 |
+
elif opt.caption_model == 'newfc':
|
30 |
+
model = NewFCModel(opt)
|
31 |
+
elif opt.caption_model == 'show_tell':
|
32 |
+
model = ShowTellModel(opt)
|
33 |
+
# Att2in model in self-critical
|
34 |
+
elif opt.caption_model == 'att2in':
|
35 |
+
model = Att2inModel(opt)
|
36 |
+
# Att2in model with two-layer MLP img embedding and word embedding
|
37 |
+
elif opt.caption_model == 'att2in2':
|
38 |
+
model = Att2in2Model(opt)
|
39 |
+
elif opt.caption_model == 'att2all2':
|
40 |
+
print('Warning: this is not a correct implementation of the att2all model in the original paper.')
|
41 |
+
model = Att2all2Model(opt)
|
42 |
+
# Adaptive Attention model from Knowing when to look
|
43 |
+
elif opt.caption_model == 'adaatt':
|
44 |
+
model = AdaAttModel(opt)
|
45 |
+
# Adaptive Attention with maxout lstm
|
46 |
+
elif opt.caption_model == 'adaattmo':
|
47 |
+
model = AdaAttMOModel(opt)
|
48 |
+
# Top-down attention model
|
49 |
+
elif opt.caption_model in ['topdown', 'updown']:
|
50 |
+
model = UpDownModel(opt)
|
51 |
+
# StackAtt
|
52 |
+
elif opt.caption_model == 'stackatt':
|
53 |
+
model = StackAttModel(opt)
|
54 |
+
# DenseAtt
|
55 |
+
elif opt.caption_model == 'denseatt':
|
56 |
+
model = DenseAttModel(opt)
|
57 |
+
# Transformer
|
58 |
+
elif opt.caption_model == 'transformer':
|
59 |
+
if getattr(opt, 'cached_transformer', False):
|
60 |
+
model = cachedTransformer(opt)
|
61 |
+
else:
|
62 |
+
model = TransformerModel(opt)
|
63 |
+
# AoANet
|
64 |
+
elif opt.caption_model == 'aoa':
|
65 |
+
model = AoAModel(opt)
|
66 |
+
elif opt.caption_model == 'bert':
|
67 |
+
model = BertCapModel(opt)
|
68 |
+
elif opt.caption_model == 'm2transformer':
|
69 |
+
model = M2TransformerModel(opt)
|
70 |
+
else:
|
71 |
+
raise Exception("Caption model not supported: {}".format(opt.caption_model))
|
72 |
+
|
73 |
+
return model
|
captioning/models/cachedTransformer.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file contains Transformer network
|
2 |
+
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
|
3 |
+
|
4 |
+
# The cfg name correspondance:
|
5 |
+
# N=num_layers
|
6 |
+
# d_model=input_encoding_size
|
7 |
+
# d_ff=rnn_size
|
8 |
+
# h is always 8
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import math
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from .CaptionModel import CaptionModel
|
24 |
+
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
|
25 |
+
|
26 |
+
class EncoderDecoder(nn.Module):
|
27 |
+
"""
|
28 |
+
A standard Encoder-Decoder architecture. Base for this and many
|
29 |
+
other models.
|
30 |
+
"""
|
31 |
+
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
|
32 |
+
super(EncoderDecoder, self).__init__()
|
33 |
+
self.encoder = encoder
|
34 |
+
self.decoder = decoder
|
35 |
+
self.src_embed = src_embed
|
36 |
+
self.tgt_embed = tgt_embed
|
37 |
+
self.generator = generator
|
38 |
+
|
39 |
+
def forward(self, src, tgt, src_mask, tgt_mask):
|
40 |
+
"Take in and process masked src and target sequences."
|
41 |
+
return self.decode(self.encode(src, src_mask), src_mask,
|
42 |
+
tgt, tgt_mask)
|
43 |
+
|
44 |
+
def encode(self, src, src_mask):
|
45 |
+
return self.encoder(self.src_embed(src), src_mask)
|
46 |
+
|
47 |
+
def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
|
48 |
+
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
|
49 |
+
|
50 |
+
class Generator(nn.Module):
|
51 |
+
"Define standard linear + softmax generation step."
|
52 |
+
def __init__(self, d_model, vocab):
|
53 |
+
super(Generator, self).__init__()
|
54 |
+
self.proj = nn.Linear(d_model, vocab)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return F.log_softmax(self.proj(x), dim=-1)
|
58 |
+
|
59 |
+
def clones(module, N):
|
60 |
+
"Produce N identical layers."
|
61 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
62 |
+
|
63 |
+
class Encoder(nn.Module):
|
64 |
+
"Core encoder is a stack of N layers"
|
65 |
+
def __init__(self, layer, N):
|
66 |
+
super(Encoder, self).__init__()
|
67 |
+
self.layers = clones(layer, N)
|
68 |
+
self.norm = LayerNorm(layer.size)
|
69 |
+
|
70 |
+
def forward(self, x, mask):
|
71 |
+
"Pass the input (and mask) through each layer in turn."
|
72 |
+
for layer in self.layers:
|
73 |
+
x = layer(x, mask)
|
74 |
+
return self.norm(x)
|
75 |
+
|
76 |
+
class LayerNorm(nn.Module):
|
77 |
+
"Construct a layernorm module (See citation for details)."
|
78 |
+
def __init__(self, features, eps=1e-6):
|
79 |
+
super(LayerNorm, self).__init__()
|
80 |
+
self.a_2 = nn.Parameter(torch.ones(features))
|
81 |
+
self.b_2 = nn.Parameter(torch.zeros(features))
|
82 |
+
self.eps = eps
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
mean = x.mean(-1, keepdim=True)
|
86 |
+
std = x.std(-1, keepdim=True)
|
87 |
+
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
88 |
+
|
89 |
+
class SublayerConnection(nn.Module):
|
90 |
+
"""
|
91 |
+
A residual connection followed by a layer norm.
|
92 |
+
Note for code simplicity the norm is first as opposed to last.
|
93 |
+
"""
|
94 |
+
def __init__(self, size, dropout):
|
95 |
+
super(SublayerConnection, self).__init__()
|
96 |
+
self.norm = LayerNorm(size)
|
97 |
+
self.dropout = nn.Dropout(dropout)
|
98 |
+
|
99 |
+
def forward(self, x, sublayer):
|
100 |
+
"Apply residual connection to any sublayer with the same size."
|
101 |
+
_x = sublayer(self.norm(x))
|
102 |
+
if type(_x) is tuple: # for multi-head attention that returns past
|
103 |
+
return x + self.dropout(_x[0]), _x[1]
|
104 |
+
return x + self.dropout(_x)
|
105 |
+
|
106 |
+
class EncoderLayer(nn.Module):
|
107 |
+
"Encoder is made up of self-attn and feed forward (defined below)"
|
108 |
+
def __init__(self, size, self_attn, feed_forward, dropout):
|
109 |
+
super(EncoderLayer, self).__init__()
|
110 |
+
self.self_attn = self_attn
|
111 |
+
self.feed_forward = feed_forward
|
112 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 2)
|
113 |
+
self.size = size
|
114 |
+
|
115 |
+
def forward(self, x, mask):
|
116 |
+
"Follow Figure 1 (left) for connections."
|
117 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
118 |
+
return self.sublayer[1](x, self.feed_forward)
|
119 |
+
|
120 |
+
class Decoder(nn.Module):
|
121 |
+
"Generic N layer decoder with masking."
|
122 |
+
def __init__(self, layer, N):
|
123 |
+
super(Decoder, self).__init__()
|
124 |
+
self.layers = clones(layer, N)
|
125 |
+
self.norm = LayerNorm(layer.size)
|
126 |
+
|
127 |
+
def forward(self, x, memory, src_mask, tgt_mask, past=None):
|
128 |
+
if past is not None:
|
129 |
+
present = [[], []]
|
130 |
+
x = x[:, -1:]
|
131 |
+
tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
|
132 |
+
past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
|
133 |
+
else:
|
134 |
+
past = [None] * len(self.layers)
|
135 |
+
for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
|
136 |
+
x = layer(x, memory, src_mask, tgt_mask,
|
137 |
+
layer_past)
|
138 |
+
if layer_past is not None:
|
139 |
+
present[0].append(x[1][0])
|
140 |
+
present[1].append(x[1][1])
|
141 |
+
x = x[0]
|
142 |
+
if past[0] is None:
|
143 |
+
return self.norm(x)
|
144 |
+
else:
|
145 |
+
return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
|
146 |
+
|
147 |
+
|
148 |
+
class DecoderLayer(nn.Module):
|
149 |
+
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
|
150 |
+
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
|
151 |
+
super(DecoderLayer, self).__init__()
|
152 |
+
self.size = size
|
153 |
+
self.self_attn = self_attn
|
154 |
+
self.src_attn = src_attn
|
155 |
+
self.feed_forward = feed_forward
|
156 |
+
self.sublayer = clones(SublayerConnection(size, dropout), 3)
|
157 |
+
|
158 |
+
def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
|
159 |
+
"Follow Figure 1 (right) for connections."
|
160 |
+
m = memory
|
161 |
+
if layer_past is None:
|
162 |
+
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
|
163 |
+
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
|
164 |
+
return self.sublayer[2](x, self.feed_forward)
|
165 |
+
else:
|
166 |
+
present = [None, None]
|
167 |
+
x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
|
168 |
+
x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
|
169 |
+
return self.sublayer[2](x, self.feed_forward), present
|
170 |
+
|
171 |
+
def subsequent_mask(size):
|
172 |
+
"Mask out subsequent positions."
|
173 |
+
attn_shape = (1, size, size)
|
174 |
+
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
|
175 |
+
return torch.from_numpy(subsequent_mask) == 0
|
176 |
+
|
177 |
+
def attention(query, key, value, mask=None, dropout=None):
|
178 |
+
"Compute 'Scaled Dot Product Attention'"
|
179 |
+
d_k = query.size(-1)
|
180 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) \
|
181 |
+
/ math.sqrt(d_k)
|
182 |
+
if mask is not None:
|
183 |
+
scores = scores.masked_fill(mask == 0, float('-inf'))
|
184 |
+
p_attn = F.softmax(scores, dim = -1)
|
185 |
+
if dropout is not None:
|
186 |
+
p_attn = dropout(p_attn)
|
187 |
+
return torch.matmul(p_attn, value), p_attn
|
188 |
+
|
189 |
+
class MultiHeadedAttention(nn.Module):
|
190 |
+
def __init__(self, h, d_model, dropout=0.1):
|
191 |
+
"Take in model size and number of heads."
|
192 |
+
super(MultiHeadedAttention, self).__init__()
|
193 |
+
assert d_model % h == 0
|
194 |
+
# We assume d_v always equals d_k
|
195 |
+
self.d_k = d_model // h
|
196 |
+
self.h = h
|
197 |
+
self.linears = clones(nn.Linear(d_model, d_model), 4)
|
198 |
+
self.attn = None
|
199 |
+
self.dropout = nn.Dropout(p=dropout)
|
200 |
+
|
201 |
+
def forward(self, query, key, value, mask=None, layer_past=None):
|
202 |
+
"Implements Figure 2"
|
203 |
+
if mask is not None:
|
204 |
+
# Same mask applied to all h heads.
|
205 |
+
mask = mask.unsqueeze(1)
|
206 |
+
nbatches = query.size(0)
|
207 |
+
|
208 |
+
# The past works differently here. For self attn, the query and key be updated incrementailly
|
209 |
+
# For src_attn the past is fixed.
|
210 |
+
|
211 |
+
# For src_attn, when the layer past is ready
|
212 |
+
if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1
|
213 |
+
query = self.linears[0](query)
|
214 |
+
key, value = layer_past[0], layer_past[1]
|
215 |
+
present = torch.stack([key, value])
|
216 |
+
else:
|
217 |
+
# 1) Do all the linear projections in batch from d_model => h x d_k
|
218 |
+
query, key, value = \
|
219 |
+
[l(x) for l, x in zip(self.linears, (query, key, value))]
|
220 |
+
|
221 |
+
# self attn + past OR the first time step of src attn
|
222 |
+
if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
|
223 |
+
past_key, past_value = layer_past[0], layer_past[1]
|
224 |
+
key = torch.cat((past_key, key), dim=1)
|
225 |
+
value = torch.cat((past_value, value), dim=1)
|
226 |
+
present = torch.stack([key, value])
|
227 |
+
|
228 |
+
query, key, value = \
|
229 |
+
[x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
230 |
+
for x in [query, key, value]]
|
231 |
+
|
232 |
+
# 2) Apply attention on all the projected vectors in batch.
|
233 |
+
x, self.attn = attention(query, key, value, mask=mask,
|
234 |
+
dropout=self.dropout)
|
235 |
+
|
236 |
+
# 3) "Concat" using a view and apply a final linear.
|
237 |
+
x = x.transpose(1, 2).contiguous() \
|
238 |
+
.view(nbatches, -1, self.h * self.d_k)
|
239 |
+
if layer_past is not None:
|
240 |
+
return self.linears[-1](x), present
|
241 |
+
else:
|
242 |
+
return self.linears[-1](x)
|
243 |
+
|
244 |
+
class PositionwiseFeedForward(nn.Module):
|
245 |
+
"Implements FFN equation."
|
246 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
247 |
+
super(PositionwiseFeedForward, self).__init__()
|
248 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
249 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
250 |
+
self.dropout = nn.Dropout(dropout)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
254 |
+
|
255 |
+
class Embeddings(nn.Module):
|
256 |
+
def __init__(self, d_model, vocab):
|
257 |
+
super(Embeddings, self).__init__()
|
258 |
+
self.lut = nn.Embedding(vocab, d_model)
|
259 |
+
self.d_model = d_model
|
260 |
+
|
261 |
+
def forward(self, x):
|
262 |
+
return self.lut(x) * math.sqrt(self.d_model)
|
263 |
+
|
264 |
+
class PositionalEncoding(nn.Module):
|
265 |
+
"Implement the PE function."
|
266 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
267 |
+
super(PositionalEncoding, self).__init__()
|
268 |
+
self.dropout = nn.Dropout(p=dropout)
|
269 |
+
|
270 |
+
# Compute the positional encodings once in log space.
|
271 |
+
pe = torch.zeros(max_len, d_model)
|
272 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
273 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
274 |
+
-(math.log(10000.0) / d_model))
|
275 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
276 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
277 |
+
pe = pe.unsqueeze(0)
|
278 |
+
self.register_buffer('pe', pe)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
x = x + self.pe[:, :x.size(1)]
|
282 |
+
return self.dropout(x)
|
283 |
+
|
284 |
+
class TransformerModel(AttModel):
|
285 |
+
|
286 |
+
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
|
287 |
+
d_model=512, d_ff=2048, h=8, dropout=0.1):
|
288 |
+
"Helper: Construct a model from hyperparameters."
|
289 |
+
c = copy.deepcopy
|
290 |
+
attn = MultiHeadedAttention(h, d_model, dropout)
|
291 |
+
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
|
292 |
+
position = PositionalEncoding(d_model, dropout)
|
293 |
+
model = EncoderDecoder(
|
294 |
+
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
|
295 |
+
Decoder(DecoderLayer(d_model, c(attn), c(attn),
|
296 |
+
c(ff), dropout), N_dec),
|
297 |
+
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
|
298 |
+
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
|
299 |
+
Generator(d_model, tgt_vocab))
|
300 |
+
|
301 |
+
# This was important from their code.
|
302 |
+
# Initialize parameters with Glorot / fan_avg.
|
303 |
+
for p in model.parameters():
|
304 |
+
if p.dim() > 1:
|
305 |
+
nn.init.xavier_uniform_(p)
|
306 |
+
return model
|
307 |
+
|
308 |
+
def __init__(self, opt):
|
309 |
+
super(TransformerModel, self).__init__(opt)
|
310 |
+
self.opt = opt
|
311 |
+
# self.config = yaml.load(open(opt.config_file))
|
312 |
+
|
313 |
+
self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
|
314 |
+
self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
|
315 |
+
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
316 |
+
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
|
317 |
+
self.h = getattr(opt, 'num_att_heads', 8)
|
318 |
+
self.dropout = getattr(opt, 'dropout', 0.1)
|
319 |
+
|
320 |
+
delattr(self, 'att_embed')
|
321 |
+
self.att_embed = nn.Sequential(*(
|
322 |
+
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
323 |
+
(nn.Linear(self.att_feat_size, self.d_model),
|
324 |
+
nn.ReLU(),
|
325 |
+
nn.Dropout(self.drop_prob_lm))+
|
326 |
+
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
|
327 |
+
|
328 |
+
delattr(self, 'embed')
|
329 |
+
self.embed = lambda x : x
|
330 |
+
delattr(self, 'fc_embed')
|
331 |
+
self.fc_embed = lambda x : x
|
332 |
+
delattr(self, 'logit')
|
333 |
+
del self.ctx2att
|
334 |
+
|
335 |
+
tgt_vocab = self.vocab_size + 1
|
336 |
+
|
337 |
+
|
338 |
+
self.model = self.make_model(0, tgt_vocab,
|
339 |
+
N_enc=self.N_enc,
|
340 |
+
N_dec=self.N_dec,
|
341 |
+
d_model=self.d_model,
|
342 |
+
d_ff=self.d_ff,
|
343 |
+
h=self.h,
|
344 |
+
dropout=self.dropout)
|
345 |
+
|
346 |
+
def logit(self, x): # unsafe way
|
347 |
+
return self.model.generator.proj(x)
|
348 |
+
|
349 |
+
def init_hidden(self, bsz):
|
350 |
+
return []
|
351 |
+
|
352 |
+
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
353 |
+
|
354 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
|
355 |
+
memory = self.model.encode(att_feats, att_masks)
|
356 |
+
|
357 |
+
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
|
358 |
+
|
359 |
+
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
|
360 |
+
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
361 |
+
|
362 |
+
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
363 |
+
|
364 |
+
if att_masks is None:
|
365 |
+
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
|
366 |
+
att_masks = att_masks.unsqueeze(-2)
|
367 |
+
|
368 |
+
if seq is not None:
|
369 |
+
# crop the last one
|
370 |
+
# seq = seq[:,:-1]
|
371 |
+
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
|
372 |
+
seq_mask[:,0] = 1 # bos
|
373 |
+
|
374 |
+
seq_mask = seq_mask.unsqueeze(-2)
|
375 |
+
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
|
376 |
+
|
377 |
+
seq_per_img = seq.shape[0] // att_feats.shape[0]
|
378 |
+
if seq_per_img > 1:
|
379 |
+
att_feats, att_masks = utils.repeat_tensors(seq_per_img,
|
380 |
+
[att_feats, att_masks]
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
seq_mask = None
|
384 |
+
|
385 |
+
return att_feats, seq, att_masks, seq_mask
|
386 |
+
|
387 |
+
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
388 |
+
if seq.ndim == 3: # B * seq_per_img * seq_len
|
389 |
+
seq = seq.reshape(-1, seq.shape[2])
|
390 |
+
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
|
391 |
+
|
392 |
+
out = self.model(att_feats, seq, att_masks, seq_mask)
|
393 |
+
|
394 |
+
outputs = self.model.generator(out)
|
395 |
+
return outputs
|
396 |
+
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
|
397 |
+
|
398 |
+
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
|
399 |
+
"""
|
400 |
+
state is the precomputed key/value. N_dec x seq_len x d_model
|
401 |
+
Note: due to the layer norm, it's not equivalant to stateless,
|
402 |
+
but it seems behaving similar
|
403 |
+
"""
|
404 |
+
# state is tokens + past
|
405 |
+
if len(state) == 0:
|
406 |
+
ys = it.unsqueeze(1)
|
407 |
+
# basically empty state, just to let it know to return past
|
408 |
+
# The second dim has to be batch_size, for beam search purpose
|
409 |
+
past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
|
410 |
+
fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
|
411 |
+
# 2 for self attn, 2 for src attn
|
412 |
+
else:
|
413 |
+
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
|
414 |
+
past = state[1:]
|
415 |
+
out, past = self.model.decode(memory, mask,
|
416 |
+
ys, # We still feed the full past words, because we need it for position embedding to know the position id
|
417 |
+
subsequent_mask(ys.size(1))
|
418 |
+
.to(memory.device),
|
419 |
+
past=past)
|
420 |
+
return out[:, -1], [ys.unsqueeze(0)] + past
|
captioning/models/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def repeat_tensors(n, x):
|
4 |
+
"""
|
5 |
+
For a tensor of size Bx..., we repeat it n times, and make it Bnx...
|
6 |
+
For collections, do nested repeat
|
7 |
+
"""
|
8 |
+
if torch.is_tensor(x):
|
9 |
+
x = x.unsqueeze(1) # Bx1x...
|
10 |
+
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
|
11 |
+
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
|
12 |
+
elif type(x) is list or type(x) is tuple:
|
13 |
+
x = [repeat_tensors(n, _) for _ in x]
|
14 |
+
return x
|
15 |
+
|
16 |
+
|
17 |
+
def split_tensors(n, x):
|
18 |
+
if torch.is_tensor(x):
|
19 |
+
assert x.shape[0] % n == 0
|
20 |
+
x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
|
21 |
+
elif type(x) is list or type(x) is tuple:
|
22 |
+
x = [split_tensors(n, _) for _ in x]
|
23 |
+
elif x is None:
|
24 |
+
x = [None] * n
|
25 |
+
return x
|
captioning/modules/loss_wrapper.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from . import losses
|
3 |
+
from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward
|
4 |
+
from ..utils.clipscore import CLIPScore
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class LossWrapper(torch.nn.Module):
|
8 |
+
def __init__(self, model, opt):
|
9 |
+
super(LossWrapper, self).__init__()
|
10 |
+
self.opt = opt
|
11 |
+
self.model = model
|
12 |
+
if opt.label_smoothing > 0:
|
13 |
+
self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
|
14 |
+
else:
|
15 |
+
self.crit = losses.LanguageModelCriterion()
|
16 |
+
self.rl_crit = losses.RewardCriterion()
|
17 |
+
self.struc_crit = losses.StructureLosses(opt)
|
18 |
+
|
19 |
+
self.clipscore_model = None
|
20 |
+
if self.opt.use_clipscore:
|
21 |
+
use_grammar = getattr(self.opt, 'use_grammar', False)
|
22 |
+
joint_out = getattr(self.opt, 'joint_out', False)
|
23 |
+
self.clipscore_model = CLIPScore(
|
24 |
+
mode=opt.clipscore_mode,
|
25 |
+
use_grammar=use_grammar,
|
26 |
+
joint_out=joint_out,
|
27 |
+
)
|
28 |
+
for p in self.clipscore_model.parameters():
|
29 |
+
p.requires_grad = False
|
30 |
+
|
31 |
+
if use_grammar:
|
32 |
+
state_dict = torch.load(self.opt.clip_load_path, map_location='cpu')
|
33 |
+
self.clipscore_model.load_state_dict(state_dict['state_dict'])
|
34 |
+
|
35 |
+
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
|
36 |
+
sc_flag, struc_flag, clip_vis_feats=None):
|
37 |
+
opt = self.opt
|
38 |
+
|
39 |
+
out = {}
|
40 |
+
if struc_flag:
|
41 |
+
if opt.structure_loss_weight < 1:
|
42 |
+
lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
|
43 |
+
else:
|
44 |
+
lm_loss = torch.tensor(0).type_as(fc_feats)
|
45 |
+
if opt.structure_loss_weight > 0:
|
46 |
+
gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
|
47 |
+
opt={'sample_method':opt.train_sample_method,
|
48 |
+
'beam_size':opt.train_beam_size,
|
49 |
+
'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
|
50 |
+
or not 'margin' in opt.structure_loss_type,
|
51 |
+
'sample_n': opt.train_sample_n},
|
52 |
+
mode='sample')
|
53 |
+
gts = [gts[_] for _ in gt_indices.tolist()]
|
54 |
+
struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
|
55 |
+
else:
|
56 |
+
struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
|
57 |
+
'reward': torch.tensor(0).type_as(fc_feats)}
|
58 |
+
loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
|
59 |
+
out['lm_loss'] = lm_loss
|
60 |
+
out['struc_loss'] = struc_loss['loss']
|
61 |
+
out['reward'] = struc_loss['reward']
|
62 |
+
elif not sc_flag:
|
63 |
+
loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
|
64 |
+
else:
|
65 |
+
self.model.eval()
|
66 |
+
with torch.no_grad():
|
67 |
+
greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
|
68 |
+
mode='sample',
|
69 |
+
opt={'sample_method': opt.sc_sample_method,
|
70 |
+
'beam_size': opt.sc_beam_size})
|
71 |
+
self.model.train()
|
72 |
+
gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
|
73 |
+
opt={'sample_method':opt.train_sample_method,
|
74 |
+
'beam_size':opt.train_beam_size,
|
75 |
+
'sample_n': opt.train_sample_n},
|
76 |
+
mode='sample')
|
77 |
+
gts = [gts[_] for _ in gt_indices.tolist()]
|
78 |
+
|
79 |
+
if getattr(self.opt, 'use_multi_rewards', False):
|
80 |
+
assert self.opt.use_clipscore
|
81 |
+
clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward(
|
82 |
+
greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
|
83 |
+
|
84 |
+
if self.opt.clipscore_mode == 'clip_s':
|
85 |
+
out['CLIP-S'] = clipscore_unnormalized_mean
|
86 |
+
elif self.opt.clipscore_mode == 'refclip_s':
|
87 |
+
out['RefCLIP-S'] = clipscore_unnormalized_mean
|
88 |
+
|
89 |
+
if getattr(self.opt, 'use_grammar', False):
|
90 |
+
out['grammar_reward'] = grammar_rewards.mean()
|
91 |
+
|
92 |
+
reward = clipscore_reward_normalized + grammar_rewards
|
93 |
+
|
94 |
+
|
95 |
+
else:
|
96 |
+
assert grammar_rewards is None
|
97 |
+
|
98 |
+
cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
|
99 |
+
greedy_res, gts, gen_result, self.opt)
|
100 |
+
out['CIDEr'] = cider_unnormalized_mean
|
101 |
+
if isinstance(cider_reward_normalized, np.ndarray):
|
102 |
+
cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device)
|
103 |
+
|
104 |
+
reward = clipscore_reward_normalized + cider_reward_normalized
|
105 |
+
else:
|
106 |
+
if self.opt.use_clipscore:
|
107 |
+
clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward(
|
108 |
+
greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
|
109 |
+
if self.opt.clipscore_mode == 'clip_s':
|
110 |
+
out['CLIP-S'] = clipscore_unnormalized_mean
|
111 |
+
elif self.opt.clipscore_mode == 'refclip_s':
|
112 |
+
out['RefCLIP-S'] = clipscore_unnormalized_mean
|
113 |
+
reward = clipscore_reward_normalized
|
114 |
+
else:
|
115 |
+
cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
|
116 |
+
greedy_res, gts, gen_result, self.opt)
|
117 |
+
out['CIDEr'] = cider_unnormalized_mean
|
118 |
+
reward = cider_reward_normalized
|
119 |
+
|
120 |
+
if isinstance(reward, np.ndarray):
|
121 |
+
reward = torch.from_numpy(reward)
|
122 |
+
reward = reward.to(sample_logprobs)
|
123 |
+
loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
|
124 |
+
out['reward'] = reward[:,0].mean()
|
125 |
+
out['loss'] = loss
|
126 |
+
return out
|
127 |
+
|
captioning/modules/losses.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ..utils.rewards import get_scores, get_self_cider_scores
|
4 |
+
|
5 |
+
class RewardCriterion(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super(RewardCriterion, self).__init__()
|
8 |
+
|
9 |
+
def forward(self, input, seq, reward):
|
10 |
+
input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
|
11 |
+
|
12 |
+
input = input.reshape(-1)
|
13 |
+
reward = reward.reshape(-1)
|
14 |
+
mask = (seq>0).to(input)
|
15 |
+
mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
|
16 |
+
output = - input * reward * mask
|
17 |
+
output = torch.sum(output) / torch.sum(mask)
|
18 |
+
|
19 |
+
return output
|
20 |
+
|
21 |
+
class StructureLosses(nn.Module):
|
22 |
+
"""
|
23 |
+
This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
|
24 |
+
"""
|
25 |
+
def __init__(self, opt):
|
26 |
+
super(StructureLosses, self).__init__()
|
27 |
+
self.opt = opt
|
28 |
+
self.loss_type = opt.structure_loss_type
|
29 |
+
|
30 |
+
def forward(self, input, seq, data_gts):
|
31 |
+
"""
|
32 |
+
Input is either logits or log softmax
|
33 |
+
"""
|
34 |
+
out = {}
|
35 |
+
|
36 |
+
batch_size = input.size(0)# batch_size = sample_size * seq_per_img
|
37 |
+
seq_per_img = batch_size // len(data_gts)
|
38 |
+
|
39 |
+
assert seq_per_img == self.opt.train_sample_n, seq_per_img
|
40 |
+
|
41 |
+
mask = (seq>0).to(input)
|
42 |
+
mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
|
43 |
+
|
44 |
+
scores = get_scores(data_gts, seq, self.opt)
|
45 |
+
scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
|
46 |
+
out['reward'] = scores #.mean()
|
47 |
+
if self.opt.entropy_reward_weight > 0:
|
48 |
+
entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
|
49 |
+
entropy = (entropy * mask).sum(1) / mask.sum(1)
|
50 |
+
print('entropy', entropy.mean().item())
|
51 |
+
scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
|
52 |
+
# rescale cost to [0,1]
|
53 |
+
costs = - scores
|
54 |
+
if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
|
55 |
+
costs = costs - costs.min(1, keepdim=True)[0]
|
56 |
+
costs = costs / costs.max(1, keepdim=True)[0]
|
57 |
+
# in principle
|
58 |
+
# Only risk need such rescale
|
59 |
+
# margin should be alright; Let's try.
|
60 |
+
|
61 |
+
# Gather input: BxTxD -> BxT
|
62 |
+
input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
|
63 |
+
|
64 |
+
if self.loss_type == 'seqnll':
|
65 |
+
# input is logsoftmax
|
66 |
+
input = input * mask
|
67 |
+
input = input.sum(1) / mask.sum(1)
|
68 |
+
input = input.view(-1, seq_per_img)
|
69 |
+
|
70 |
+
target = costs.min(1)[1]
|
71 |
+
output = F.cross_entropy(input, target)
|
72 |
+
elif self.loss_type == 'risk':
|
73 |
+
# input is logsoftmax
|
74 |
+
input = input * mask
|
75 |
+
input = input.sum(1)
|
76 |
+
input = input.view(-1, seq_per_img)
|
77 |
+
|
78 |
+
output = (F.softmax(input.exp()) * costs).sum(1).mean()
|
79 |
+
|
80 |
+
# test
|
81 |
+
# avg_scores = input
|
82 |
+
# probs = F.softmax(avg_scores.exp_())
|
83 |
+
# loss = (probs * costs.type_as(probs)).sum() / input.size(0)
|
84 |
+
# print(output.item(), loss.item())
|
85 |
+
|
86 |
+
elif self.loss_type == 'max_margin':
|
87 |
+
# input is logits
|
88 |
+
input = input * mask
|
89 |
+
input = input.sum(1) / mask.sum(1)
|
90 |
+
input = input.view(-1, seq_per_img)
|
91 |
+
_, __ = costs.min(1, keepdim=True)
|
92 |
+
costs_star = _
|
93 |
+
input_star = input.gather(1, __)
|
94 |
+
output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
|
95 |
+
output = output.mean()
|
96 |
+
|
97 |
+
# sanity test
|
98 |
+
# avg_scores = input + costs
|
99 |
+
# scores_with_high_target = avg_scores.clone()
|
100 |
+
# scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
|
101 |
+
|
102 |
+
# target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
|
103 |
+
# avg_scores = avg_scores.gather(1, target_and_offender_index)
|
104 |
+
# target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
|
105 |
+
# loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
|
106 |
+
# print(loss.item() * 2, output.item())
|
107 |
+
|
108 |
+
elif self.loss_type == 'multi_margin':
|
109 |
+
# input is logits
|
110 |
+
input = input * mask
|
111 |
+
input = input.sum(1) / mask.sum(1)
|
112 |
+
input = input.view(-1, seq_per_img)
|
113 |
+
_, __ = costs.min(1, keepdim=True)
|
114 |
+
costs_star = _
|
115 |
+
input_star = input.gather(1, __)
|
116 |
+
output = F.relu(costs - costs_star - input_star + input)
|
117 |
+
output = output.mean()
|
118 |
+
|
119 |
+
# sanity test
|
120 |
+
# avg_scores = input + costs
|
121 |
+
# loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
|
122 |
+
# print(output, loss)
|
123 |
+
|
124 |
+
elif self.loss_type == 'softmax_margin':
|
125 |
+
# input is logsoftmax
|
126 |
+
input = input * mask
|
127 |
+
input = input.sum(1) / mask.sum(1)
|
128 |
+
input = input.view(-1, seq_per_img)
|
129 |
+
|
130 |
+
input = input + costs
|
131 |
+
target = costs.min(1)[1]
|
132 |
+
output = F.cross_entropy(input, target)
|
133 |
+
|
134 |
+
elif self.loss_type == 'real_softmax_margin':
|
135 |
+
# input is logits
|
136 |
+
# This is what originally defined in Kevin's paper
|
137 |
+
# The result should be equivalent to softmax_margin
|
138 |
+
input = input * mask
|
139 |
+
input = input.sum(1) / mask.sum(1)
|
140 |
+
input = input.view(-1, seq_per_img)
|
141 |
+
|
142 |
+
input = input + costs
|
143 |
+
target = costs.min(1)[1]
|
144 |
+
output = F.cross_entropy(input, target)
|
145 |
+
|
146 |
+
elif self.loss_type == 'new_self_critical':
|
147 |
+
"""
|
148 |
+
A different self critical
|
149 |
+
Self critical uses greedy decoding score as baseline;
|
150 |
+
This setting uses the average score of the rest samples as baseline
|
151 |
+
(suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) )
|
152 |
+
"""
|
153 |
+
baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
|
154 |
+
scores = scores - baseline
|
155 |
+
# self cider used as reward to promote diversity (not working that much in this way)
|
156 |
+
if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
|
157 |
+
_scores = get_self_cider_scores(data_gts, seq, self.opt)
|
158 |
+
_scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
|
159 |
+
_scores = _scores.expand_as(scores - 1)
|
160 |
+
scores += self.opt.self_cider_reward_weight * _scores
|
161 |
+
output = - input * mask * scores.view(-1, 1)
|
162 |
+
output = torch.sum(output) / torch.sum(mask)
|
163 |
+
|
164 |
+
out['loss'] = output
|
165 |
+
return out
|
166 |
+
|
167 |
+
class LanguageModelCriterion(nn.Module):
|
168 |
+
def __init__(self):
|
169 |
+
super(LanguageModelCriterion, self).__init__()
|
170 |
+
|
171 |
+
def forward(self, input, target, mask):
|
172 |
+
if target.ndim == 3:
|
173 |
+
target = target.reshape(-1, target.shape[2])
|
174 |
+
mask = mask.reshape(-1, mask.shape[2])
|
175 |
+
# truncate to the same size
|
176 |
+
target = target[:, :input.size(1)]
|
177 |
+
mask = mask[:, :input.size(1)].to(input)
|
178 |
+
|
179 |
+
output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
|
180 |
+
# Average over each token
|
181 |
+
output = torch.sum(output) / torch.sum(mask)
|
182 |
+
|
183 |
+
return output
|
184 |
+
|
185 |
+
class LabelSmoothing(nn.Module):
|
186 |
+
"Implement label smoothing."
|
187 |
+
def __init__(self, size=0, padding_idx=0, smoothing=0.0):
|
188 |
+
super(LabelSmoothing, self).__init__()
|
189 |
+
self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
|
190 |
+
# self.padding_idx = padding_idx
|
191 |
+
self.confidence = 1.0 - smoothing
|
192 |
+
self.smoothing = smoothing
|
193 |
+
# self.size = size
|
194 |
+
self.true_dist = None
|
195 |
+
|
196 |
+
def forward(self, input, target, mask):
|
197 |
+
if target.ndim == 3:
|
198 |
+
target = target.reshape(-1, target.shape[2])
|
199 |
+
mask = mask.reshape(-1, mask.shape[2])
|
200 |
+
# truncate to the same size
|
201 |
+
target = target[:, :input.size(1)]
|
202 |
+
mask = mask[:, :input.size(1)]
|
203 |
+
|
204 |
+
input = input.reshape(-1, input.size(-1))
|
205 |
+
target = target.reshape(-1)
|
206 |
+
mask = mask.reshape(-1).to(input)
|
207 |
+
|
208 |
+
# assert x.size(1) == self.size
|
209 |
+
self.size = input.size(1)
|
210 |
+
# true_dist = x.data.clone()
|
211 |
+
true_dist = input.data.clone()
|
212 |
+
# true_dist.fill_(self.smoothing / (self.size - 2))
|
213 |
+
true_dist.fill_(self.smoothing / (self.size - 1))
|
214 |
+
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
|
215 |
+
# true_dist[:, self.padding_idx] = 0
|
216 |
+
# mask = torch.nonzero(target.data == self.padding_idx)
|
217 |
+
# self.true_dist = true_dist
|
218 |
+
return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
|
captioning/utils/__init__.py
ADDED
File without changes
|
captioning/utils/clipscore.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPModel, CLIPTokenizer
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
from random import shuffle, seed
|
6 |
+
import string
|
7 |
+
# non-standard dependencies:
|
8 |
+
import h5py
|
9 |
+
from six.moves import cPickle
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision.models as models
|
13 |
+
import skimage.io
|
14 |
+
|
15 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
16 |
+
from PIL import Image
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
|
20 |
+
class CLIPScore(nn.Module):
|
21 |
+
def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
|
22 |
+
super(CLIPScore, self).__init__()
|
23 |
+
# from transformers import CLIPModel, CLIPTokenizer
|
24 |
+
self.clip_model = CLIPModel.from_pretrained(
|
25 |
+
'openai/clip-vit-base-patch32')
|
26 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
27 |
+
'openai/clip-vit-base-patch32')
|
28 |
+
|
29 |
+
self.clip_model.eval()
|
30 |
+
|
31 |
+
self.clipscore_w = clipscore_w
|
32 |
+
|
33 |
+
self.image_transform = self._transform(image_size)
|
34 |
+
|
35 |
+
self.mode = mode
|
36 |
+
assert mode in ['clip_s', 'refclip_s']
|
37 |
+
|
38 |
+
self.use_grammar = use_grammar
|
39 |
+
self.joint_out = joint_out
|
40 |
+
|
41 |
+
if self.use_grammar and joint_out is False:
|
42 |
+
self.grammar_score_head = nn.Sequential(
|
43 |
+
nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
|
44 |
+
nn.ReLU(),
|
45 |
+
nn.Linear(self.clip_model.projection_dim, 2, bias=False)
|
46 |
+
)
|
47 |
+
|
48 |
+
def _transform(self, n_px):
|
49 |
+
return Compose([
|
50 |
+
Resize(n_px, interpolation=Image.BICUBIC),
|
51 |
+
CenterCrop(n_px),
|
52 |
+
lambda image: image.convert("RGB"),
|
53 |
+
ToTensor(),
|
54 |
+
Normalize((0.48145466, 0.4578275, 0.40821073),
|
55 |
+
(0.26862954, 0.26130258, 0.27577711)),
|
56 |
+
])
|
57 |
+
|
58 |
+
def load_image(self, image_path):
|
59 |
+
image = Image.open(image_path)
|
60 |
+
return image
|
61 |
+
|
62 |
+
# @torch.no_grad()
|
63 |
+
def image_extract(self, image):
|
64 |
+
if isinstance(image, str):
|
65 |
+
image = self.load_image(image)
|
66 |
+
if not isinstance(image, torch.Tensor):
|
67 |
+
image = self.image_transform(image)
|
68 |
+
|
69 |
+
img_tensor = image.view(-1, 3, 224, 224)
|
70 |
+
device = next(self.clip_model.parameters()).device
|
71 |
+
img_tensor = img_tensor.to(device)
|
72 |
+
|
73 |
+
clip_model = self.clip_model
|
74 |
+
|
75 |
+
img_feat = clip_model.vision_model(img_tensor).pooler_output
|
76 |
+
img_feat = clip_model.visual_projection(img_feat)
|
77 |
+
img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
|
78 |
+
|
79 |
+
return img_feat
|
80 |
+
|
81 |
+
# @torch.no_grad()
|
82 |
+
def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
|
83 |
+
if isinstance(text, str):
|
84 |
+
text_batch = [" ".join([prompt, text])]
|
85 |
+
elif isinstance(text, list):
|
86 |
+
text_batch = [" ".join([prompt, txt]) for txt in text]
|
87 |
+
|
88 |
+
if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
|
89 |
+
input_ids, attention_mask = text
|
90 |
+
else:
|
91 |
+
input_text = text_batch
|
92 |
+
|
93 |
+
tokenized = self.tokenizer(
|
94 |
+
input_text, return_tensors='pt', padding=True, truncation=True)
|
95 |
+
|
96 |
+
input_ids = tokenized.input_ids
|
97 |
+
attention_mask = tokenized.attention_mask
|
98 |
+
|
99 |
+
clip_model = self.clip_model
|
100 |
+
device = next(self.clip_model.parameters()).device
|
101 |
+
input_ids = input_ids.to(device)
|
102 |
+
attention_mask = attention_mask.to(device)
|
103 |
+
|
104 |
+
text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
|
105 |
+
|
106 |
+
if proj_norm:
|
107 |
+
text_feat = clip_model.text_projection(text_feat)
|
108 |
+
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
|
109 |
+
|
110 |
+
return text_feat
|
111 |
+
|
112 |
+
# @torch.no_grad()
|
113 |
+
def calc_clip_s(self, img_feat, text_feat):
|
114 |
+
return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
|
115 |
+
|
116 |
+
# @torch.no_grad()
|
117 |
+
def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
|
118 |
+
|
119 |
+
if clip_s is None:
|
120 |
+
clip_s = self.calc_clip_s(img_feat, text_feat)
|
121 |
+
|
122 |
+
B, dim = img_feat.size()
|
123 |
+
|
124 |
+
ref_text_feat = ref_text_feat.view(B, -1, dim)
|
125 |
+
|
126 |
+
K = ref_text_feat.size(1)
|
127 |
+
|
128 |
+
text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
|
129 |
+
assert ref_text_feat.size() == text_feat.size(
|
130 |
+
), (ref_text_feat.size(), text_feat.size())
|
131 |
+
|
132 |
+
ref_score = self.calc_clip_s(text_feat, ref_text_feat)
|
133 |
+
if ref_text_mask is not None:
|
134 |
+
if not isinstance(ref_text_mask, torch.Tensor):
|
135 |
+
ref_text_mask = torch.tensor(
|
136 |
+
ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
|
137 |
+
ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
|
138 |
+
|
139 |
+
ref_score = ref_score.view(B, K).max(dim=1).values
|
140 |
+
|
141 |
+
assert clip_s.size() == (B,)
|
142 |
+
assert clip_s.size() == ref_score.size()
|
143 |
+
|
144 |
+
# harmonic mean
|
145 |
+
refclip_s = 2 / (1 / clip_s + 1 / ref_score)
|
146 |
+
return refclip_s
|
147 |
+
|
148 |
+
@torch.no_grad()
|
149 |
+
def forward(self,
|
150 |
+
images=None, text=None,
|
151 |
+
img_feat=None, text_feat=None,
|
152 |
+
ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
153 |
+
prompt="A photo depicts",
|
154 |
+
mode=None):
|
155 |
+
if img_feat is None:
|
156 |
+
img_feat = self.image_extract(images)
|
157 |
+
img_feat = img_feat.view(-1, 512)
|
158 |
+
|
159 |
+
B = img_feat.size(0)
|
160 |
+
|
161 |
+
if text_feat is None:
|
162 |
+
text_feat = self.text_extract(text, prompt=prompt)
|
163 |
+
text_feat = text_feat.view(-1, 512)
|
164 |
+
|
165 |
+
if mode is None:
|
166 |
+
mode = self.mode
|
167 |
+
assert mode in ['clip_s', 'refclip_s']
|
168 |
+
|
169 |
+
if mode == 'clip_s':
|
170 |
+
clip_s = self.calc_clip_s(img_feat, text_feat)
|
171 |
+
return clip_s
|
172 |
+
elif mode == 'refclip_s':
|
173 |
+
if ref_text_feat is None:
|
174 |
+
ref_text_feat = self.text_extract(ref_text, prompt=prompt)
|
175 |
+
ref_text_feat = ref_text_feat.view(-1, 512)
|
176 |
+
|
177 |
+
refclip_s = self.calc_refclip_s(
|
178 |
+
img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
|
179 |
+
return refclip_s
|
180 |
+
|
181 |
+
|
182 |
+
def train_step(self,
|
183 |
+
images=None, text=None,
|
184 |
+
img_feat=None, text_feat=None,
|
185 |
+
neg_text=None, neg_text_feat=None,
|
186 |
+
# ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
187 |
+
prompt="A photo depicts",
|
188 |
+
# return_loss=True,
|
189 |
+
**kwargs):
|
190 |
+
|
191 |
+
if img_feat is None:
|
192 |
+
img_feat = self.image_extract(images)
|
193 |
+
img_feat = img_feat.view(-1, 512)
|
194 |
+
|
195 |
+
B = img_feat.size(0)
|
196 |
+
|
197 |
+
if text_feat is None:
|
198 |
+
text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
|
199 |
+
|
200 |
+
text_cont_feat = self.clip_model.text_projection(text_feat)
|
201 |
+
text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
|
202 |
+
text_cont_feat = text_cont_feat.view(B, 512)
|
203 |
+
|
204 |
+
# cosine similarity as logits
|
205 |
+
logit_scale = self.clip_model.logit_scale.exp()
|
206 |
+
logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
|
207 |
+
# logits_per_image = logits_per_text.T
|
208 |
+
|
209 |
+
clip_loss = clip_loss_fn(logits_per_text)
|
210 |
+
|
211 |
+
|
212 |
+
# negative sampling
|
213 |
+
pos_text_feat = text_feat.view(B, 512)
|
214 |
+
neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
|
215 |
+
|
216 |
+
grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
|
217 |
+
|
218 |
+
# 2B, 1
|
219 |
+
grammar_text_logit = self.grammar_score_head(grammar_text_feat)
|
220 |
+
grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
|
221 |
+
|
222 |
+
grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
|
223 |
+
|
224 |
+
grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
|
225 |
+
grammar_pos_pred = grammar_pred[:B]
|
226 |
+
grammar_neg_pred = grammar_pred[B:]
|
227 |
+
# grammar_acc = (grammar_pred == grammar_labels).float().mean()
|
228 |
+
|
229 |
+
out = {
|
230 |
+
'clip_loss': clip_loss,
|
231 |
+
'grammar_loss': grammar_loss,
|
232 |
+
'img_feat': img_feat,
|
233 |
+
'text_feat': text_cont_feat,
|
234 |
+
'neg_text_feat': neg_text_feat,
|
235 |
+
'grammar_pos_pred': grammar_pos_pred,
|
236 |
+
'grammar_neg_pred': grammar_neg_pred,
|
237 |
+
}
|
238 |
+
|
239 |
+
return out
|
240 |
+
|
241 |
+
# contrastive loss function, adapted from
|
242 |
+
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
243 |
+
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
|
244 |
+
neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
|
245 |
+
return -neg_ce.mean()
|
246 |
+
|
247 |
+
|
248 |
+
def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
|
249 |
+
caption_loss = contrastive_loss(similarity, dim=0)
|
250 |
+
image_loss = contrastive_loss(similarity, dim=1)
|
251 |
+
return (caption_loss + image_loss) / 2.0
|
252 |
+
|
253 |
+
|
254 |
+
|
255 |
+
# class CLIPScore(nn.Module):
|
256 |
+
# def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'):
|
257 |
+
# super(CLIPScore, self).__init__()
|
258 |
+
# # from transformers import CLIPModel, CLIPTokenizer
|
259 |
+
# self.clip_model = CLIPModel.from_pretrained(
|
260 |
+
# 'openai/clip-vit-base-patch32')
|
261 |
+
# self.tokenizer = CLIPTokenizer.from_pretrained(
|
262 |
+
# 'openai/clip-vit-base-patch32')
|
263 |
+
|
264 |
+
# self.clip_model.eval()
|
265 |
+
|
266 |
+
# self.clipscore_w = clipscore_w
|
267 |
+
|
268 |
+
# self.image_transform = self._transform(image_size)
|
269 |
+
|
270 |
+
# self.mode = mode
|
271 |
+
# assert mode in ['clip_s', 'refclip_s']
|
272 |
+
|
273 |
+
# def _transform(self, n_px):
|
274 |
+
# return Compose([
|
275 |
+
# Resize(n_px, interpolation=Image.BICUBIC),
|
276 |
+
# CenterCrop(n_px),
|
277 |
+
# lambda image: image.convert("RGB"),
|
278 |
+
# ToTensor(),
|
279 |
+
# Normalize((0.48145466, 0.4578275, 0.40821073),
|
280 |
+
# (0.26862954, 0.26130258, 0.27577711)),
|
281 |
+
# ])
|
282 |
+
|
283 |
+
# def load_image(self, image_path):
|
284 |
+
# image = Image.open(image_path)
|
285 |
+
# return image
|
286 |
+
|
287 |
+
# @torch.no_grad()
|
288 |
+
# def image_extract(self, image):
|
289 |
+
# if isinstance(image, str):
|
290 |
+
# image = self.load_image(image)
|
291 |
+
# if not isinstance(image, torch.Tensor):
|
292 |
+
# image = self.image_transform(image)
|
293 |
+
|
294 |
+
# img_tensor = image.view(-1, 3, 224, 224)
|
295 |
+
# device = next(self.clip_model.parameters()).device
|
296 |
+
# img_tensor = img_tensor.to(device)
|
297 |
+
|
298 |
+
# clip_model = self.clip_model
|
299 |
+
|
300 |
+
# img_feat = clip_model.vision_model(img_tensor).pooler_output
|
301 |
+
# img_feat = clip_model.visual_projection(img_feat)
|
302 |
+
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
|
303 |
+
|
304 |
+
# return img_feat
|
305 |
+
|
306 |
+
# @torch.no_grad()
|
307 |
+
# def text_extract(self, text, prompt="A photo depicts"):
|
308 |
+
# if isinstance(text, str):
|
309 |
+
# text_batch = [" ".join([prompt, text])]
|
310 |
+
# else:
|
311 |
+
# text_batch = [" ".join([prompt, txt]) for txt in text]
|
312 |
+
|
313 |
+
# input_text = text_batch
|
314 |
+
|
315 |
+
# tokenized = self.tokenizer(
|
316 |
+
# input_text, return_tensors='pt', padding=True)
|
317 |
+
|
318 |
+
# input_ids = tokenized.input_ids
|
319 |
+
# attention_mask = tokenized.attention_mask
|
320 |
+
|
321 |
+
# clip_model = self.clip_model
|
322 |
+
# device = next(self.clip_model.parameters()).device
|
323 |
+
# input_ids = input_ids.to(device)
|
324 |
+
# attention_mask = attention_mask.to(device)
|
325 |
+
|
326 |
+
# text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
|
327 |
+
# text_feat = clip_model.text_projection(text_feat)
|
328 |
+
# text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
|
329 |
+
|
330 |
+
# return text_feat
|
331 |
+
|
332 |
+
# @torch.no_grad()
|
333 |
+
# def calc_clip_s(self, img_feat, text_feat):
|
334 |
+
# return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
|
335 |
+
|
336 |
+
# @torch.no_grad()
|
337 |
+
# def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
|
338 |
+
|
339 |
+
# if clip_s is None:
|
340 |
+
# clip_s = self.calc_clip_s(img_feat, text_feat)
|
341 |
+
|
342 |
+
# B, dim = img_feat.size()
|
343 |
+
|
344 |
+
# ref_text_feat = ref_text_feat.view(B, -1, dim)
|
345 |
+
|
346 |
+
# K = ref_text_feat.size(1)
|
347 |
+
|
348 |
+
# text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
|
349 |
+
# assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size())
|
350 |
+
|
351 |
+
# ref_score = self.calc_clip_s(text_feat, ref_text_feat)
|
352 |
+
# if ref_text_mask is not None:
|
353 |
+
# if not isinstance(ref_text_mask, torch.Tensor):
|
354 |
+
# ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
|
355 |
+
# ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
|
356 |
+
|
357 |
+
# ref_score = ref_score.view(B, K).max(dim=1).values
|
358 |
+
|
359 |
+
# assert clip_s.size() == (B,)
|
360 |
+
# assert clip_s.size() == ref_score.size()
|
361 |
+
|
362 |
+
# # harmonic mean
|
363 |
+
# refclip_s = 2 / (1 / clip_s + 1 / ref_score)
|
364 |
+
# return refclip_s
|
365 |
+
|
366 |
+
|
367 |
+
# @torch.no_grad()
|
368 |
+
# def forward(self,
|
369 |
+
# images=None, text=None,
|
370 |
+
# img_feat=None, text_feat=None,
|
371 |
+
# ref_text=None, ref_text_feat=None, ref_text_mask=None,
|
372 |
+
# prompt="A photo depicts",
|
373 |
+
# mode=None):
|
374 |
+
# if img_feat is None:
|
375 |
+
# img_feat = self.image_extract(images)
|
376 |
+
# img_feat = img_feat.view(-1, 512)
|
377 |
+
|
378 |
+
# if text_feat is None:
|
379 |
+
# text_feat = self.text_extract(text, prompt=prompt)
|
380 |
+
# text_feat = text_feat.view(-1, 512)
|
381 |
+
|
382 |
+
# if mode is None:
|
383 |
+
# mode = self.mode
|
384 |
+
# assert mode in ['clip_s', 'refclip_s']
|
385 |
+
|
386 |
+
# if mode == 'clip_s':
|
387 |
+
# clip_s = self.calc_clip_s(img_feat, text_feat)
|
388 |
+
# return clip_s
|
389 |
+
# elif mode == 'refclip_s':
|
390 |
+
# if ref_text_feat is None:
|
391 |
+
# ref_text_feat = self.text_extract(ref_text, prompt=prompt)
|
392 |
+
# ref_text_feat = ref_text_feat.view(-1, 512)
|
393 |
+
|
394 |
+
# refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
|
395 |
+
# return refclip_s
|
396 |
+
|
captioning/utils/config.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
# Copy from fvcore
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from typing import Any
|
7 |
+
import yaml
|
8 |
+
from yacs.config import CfgNode as _CfgNode
|
9 |
+
|
10 |
+
import io as PathManager
|
11 |
+
|
12 |
+
BASE_KEY = "_BASE_"
|
13 |
+
|
14 |
+
|
15 |
+
class CfgNode(_CfgNode):
|
16 |
+
"""
|
17 |
+
Our own extended version of :class:`yacs.config.CfgNode`.
|
18 |
+
It contains the following extra features:
|
19 |
+
|
20 |
+
1. The :meth:`merge_from_file` method supports the "_BASE_" key,
|
21 |
+
which allows the new CfgNode to inherit all the attributes from the
|
22 |
+
base configuration file.
|
23 |
+
2. Keys that start with "COMPUTED_" are treated as insertion-only
|
24 |
+
"computed" attributes. They can be inserted regardless of whether
|
25 |
+
the CfgNode is frozen or not.
|
26 |
+
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
|
27 |
+
expressions in config. See examples in
|
28 |
+
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
|
29 |
+
Note that this may lead to arbitrary code execution: you must not
|
30 |
+
load a config file from untrusted sources before manually inspecting
|
31 |
+
the content of the file.
|
32 |
+
"""
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def load_yaml_with_base(filename, allow_unsafe = False):
|
36 |
+
"""
|
37 |
+
Just like `yaml.load(open(filename))`, but inherit attributes from its
|
38 |
+
`_BASE_`.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
filename (str): the file name of the current config. Will be used to
|
42 |
+
find the base config file.
|
43 |
+
allow_unsafe (bool): whether to allow loading the config file with
|
44 |
+
`yaml.unsafe_load`.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
(dict): the loaded yaml
|
48 |
+
"""
|
49 |
+
with PathManager.open(filename, "r") as f:
|
50 |
+
try:
|
51 |
+
cfg = yaml.safe_load(f)
|
52 |
+
except yaml.constructor.ConstructorError:
|
53 |
+
if not allow_unsafe:
|
54 |
+
raise
|
55 |
+
logger = logging.getLogger(__name__)
|
56 |
+
logger.warning(
|
57 |
+
"Loading config {} with yaml.unsafe_load. Your machine may "
|
58 |
+
"be at risk if the file contains malicious content.".format(
|
59 |
+
filename
|
60 |
+
)
|
61 |
+
)
|
62 |
+
f.close()
|
63 |
+
with open(filename, "r") as f:
|
64 |
+
cfg = yaml.unsafe_load(f)
|
65 |
+
|
66 |
+
def merge_a_into_b(a, b):
|
67 |
+
# merge dict a into dict b. values in a will overwrite b.
|
68 |
+
for k, v in a.items():
|
69 |
+
if isinstance(v, dict) and k in b:
|
70 |
+
assert isinstance(
|
71 |
+
b[k], dict
|
72 |
+
), "Cannot inherit key '{}' from base!".format(k)
|
73 |
+
merge_a_into_b(v, b[k])
|
74 |
+
else:
|
75 |
+
b[k] = v
|
76 |
+
|
77 |
+
if BASE_KEY in cfg:
|
78 |
+
base_cfg_file = cfg[BASE_KEY]
|
79 |
+
if base_cfg_file.startswith("~"):
|
80 |
+
base_cfg_file = os.path.expanduser(base_cfg_file)
|
81 |
+
if not any(
|
82 |
+
map(base_cfg_file.startswith, ["/", "https://", "http://"])
|
83 |
+
):
|
84 |
+
# the path to base cfg is relative to the config file itself.
|
85 |
+
base_cfg_file = os.path.join(
|
86 |
+
os.path.dirname(filename), base_cfg_file
|
87 |
+
)
|
88 |
+
base_cfg = CfgNode.load_yaml_with_base(
|
89 |
+
base_cfg_file, allow_unsafe=allow_unsafe
|
90 |
+
)
|
91 |
+
del cfg[BASE_KEY]
|
92 |
+
|
93 |
+
merge_a_into_b(cfg, base_cfg)
|
94 |
+
return base_cfg
|
95 |
+
return cfg
|
96 |
+
|
97 |
+
def merge_from_file(self, cfg_filename, allow_unsafe = False):
|
98 |
+
"""
|
99 |
+
Merge configs from a given yaml file.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
cfg_filename: the file name of the yaml config.
|
103 |
+
allow_unsafe: whether to allow loading the config file with
|
104 |
+
`yaml.unsafe_load`.
|
105 |
+
"""
|
106 |
+
loaded_cfg = CfgNode.load_yaml_with_base(
|
107 |
+
cfg_filename, allow_unsafe=allow_unsafe
|
108 |
+
)
|
109 |
+
loaded_cfg = type(self)(loaded_cfg)
|
110 |
+
self.merge_from_other_cfg(loaded_cfg)
|
111 |
+
|
112 |
+
# Forward the following calls to base, but with a check on the BASE_KEY.
|
113 |
+
def merge_from_other_cfg(self, cfg_other):
|
114 |
+
"""
|
115 |
+
Args:
|
116 |
+
cfg_other (CfgNode): configs to merge from.
|
117 |
+
"""
|
118 |
+
assert (
|
119 |
+
BASE_KEY not in cfg_other
|
120 |
+
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
|
121 |
+
return super().merge_from_other_cfg(cfg_other)
|
122 |
+
|
123 |
+
def merge_from_list(self, cfg_list):
|
124 |
+
"""
|
125 |
+
Args:
|
126 |
+
cfg_list (list): list of configs to merge from.
|
127 |
+
"""
|
128 |
+
keys = set(cfg_list[0::2])
|
129 |
+
assert (
|
130 |
+
BASE_KEY not in keys
|
131 |
+
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
|
132 |
+
return super().merge_from_list(cfg_list)
|
133 |
+
|
134 |
+
def __setattr__(self, name, val):
|
135 |
+
if name.startswith("COMPUTED_"):
|
136 |
+
if name in self:
|
137 |
+
old_val = self[name]
|
138 |
+
if old_val == val:
|
139 |
+
return
|
140 |
+
raise KeyError(
|
141 |
+
"Computed attributed '{}' already exists "
|
142 |
+
"with a different value! old={}, new={}.".format(
|
143 |
+
name, old_val, val
|
144 |
+
)
|
145 |
+
)
|
146 |
+
self[name] = val
|
147 |
+
else:
|
148 |
+
super().__setattr__(name, val)
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
|
153 |
+
print(cfg)
|
captioning/utils/dist_utils.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
This file contains primitives for multi-gpu communication.
|
4 |
+
This is useful when doing distributed training.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import functools
|
8 |
+
import logging
|
9 |
+
import numpy as np
|
10 |
+
import pickle
|
11 |
+
import torch
|
12 |
+
import torch.distributed as dist
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
_LOCAL_PROCESS_GROUP = None
|
17 |
+
"""
|
18 |
+
A torch process group which only includes processes that on the same machine as the current process.
|
19 |
+
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
def get_world_size() -> int:
|
24 |
+
if not dist.is_available():
|
25 |
+
return 1
|
26 |
+
if not dist.is_initialized():
|
27 |
+
return 1
|
28 |
+
return dist.get_world_size()
|
29 |
+
|
30 |
+
|
31 |
+
def get_rank() -> int:
|
32 |
+
if not dist.is_available():
|
33 |
+
return 0
|
34 |
+
if not dist.is_initialized():
|
35 |
+
return 0
|
36 |
+
return dist.get_rank()
|
37 |
+
|
38 |
+
|
39 |
+
def get_local_rank() -> int:
|
40 |
+
"""
|
41 |
+
Returns:
|
42 |
+
The rank of the current process within the local (per-machine) process group.
|
43 |
+
"""
|
44 |
+
if not dist.is_available():
|
45 |
+
return 0
|
46 |
+
if not dist.is_initialized():
|
47 |
+
return 0
|
48 |
+
assert _LOCAL_PROCESS_GROUP is not None
|
49 |
+
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
50 |
+
|
51 |
+
|
52 |
+
def get_local_size() -> int:
|
53 |
+
"""
|
54 |
+
Returns:
|
55 |
+
The size of the per-machine process group,
|
56 |
+
i.e. the number of processes per machine.
|
57 |
+
"""
|
58 |
+
if not dist.is_available():
|
59 |
+
return 1
|
60 |
+
if not dist.is_initialized():
|
61 |
+
return 1
|
62 |
+
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
63 |
+
|
64 |
+
|
65 |
+
def is_main_process() -> bool:
|
66 |
+
return get_rank() == 0
|
67 |
+
|
68 |
+
|
69 |
+
def synchronize():
|
70 |
+
"""
|
71 |
+
Helper function to synchronize (barrier) among all processes when
|
72 |
+
using distributed training
|
73 |
+
"""
|
74 |
+
if not dist.is_available():
|
75 |
+
return
|
76 |
+
if not dist.is_initialized():
|
77 |
+
return
|
78 |
+
world_size = dist.get_world_size()
|
79 |
+
if world_size == 1:
|
80 |
+
return
|
81 |
+
dist.barrier()
|
82 |
+
|
83 |
+
|
84 |
+
@functools.lru_cache()
|
85 |
+
def _get_global_gloo_group():
|
86 |
+
"""
|
87 |
+
Return a process group based on gloo backend, containing all the ranks
|
88 |
+
The result is cached.
|
89 |
+
"""
|
90 |
+
if dist.get_backend() == "nccl":
|
91 |
+
return dist.new_group(backend="gloo")
|
92 |
+
else:
|
93 |
+
return dist.group.WORLD
|
94 |
+
|
95 |
+
|
96 |
+
def _serialize_to_tensor(data, group):
|
97 |
+
backend = dist.get_backend(group)
|
98 |
+
assert backend in ["gloo", "nccl"]
|
99 |
+
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
100 |
+
|
101 |
+
buffer = pickle.dumps(data)
|
102 |
+
if len(buffer) > 1024 ** 3:
|
103 |
+
logger = logging.getLogger(__name__)
|
104 |
+
logger.warning(
|
105 |
+
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
106 |
+
get_rank(), len(buffer) / (1024 ** 3), device
|
107 |
+
)
|
108 |
+
)
|
109 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
110 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
111 |
+
return tensor
|
112 |
+
|
113 |
+
|
114 |
+
def _pad_to_largest_tensor(tensor, group):
|
115 |
+
"""
|
116 |
+
Returns:
|
117 |
+
list[int]: size of the tensor, on each rank
|
118 |
+
Tensor: padded tensor that has the max size
|
119 |
+
"""
|
120 |
+
world_size = dist.get_world_size(group=group)
|
121 |
+
assert (
|
122 |
+
world_size >= 1
|
123 |
+
), "comm.gather/all_gather must be called from ranks within the given group!"
|
124 |
+
local_size = torch.tensor(
|
125 |
+
[tensor.numel()], dtype=torch.int64, device=tensor.device)
|
126 |
+
size_list = [
|
127 |
+
torch.zeros([1], dtype=torch.int64, device=tensor.device)
|
128 |
+
for _ in range(world_size)
|
129 |
+
]
|
130 |
+
dist.all_gather(size_list, local_size, group=group)
|
131 |
+
size_list = [int(size.item()) for size in size_list]
|
132 |
+
|
133 |
+
max_size = max(size_list)
|
134 |
+
|
135 |
+
# we pad the tensor because torch all_gather does not support
|
136 |
+
# gathering tensors of different shapes
|
137 |
+
if local_size != max_size:
|
138 |
+
padding = torch.zeros(
|
139 |
+
(max_size - local_size,), dtype=torch.uint8, device=tensor.device
|
140 |
+
)
|
141 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
142 |
+
return size_list, tensor
|
143 |
+
|
144 |
+
|
145 |
+
def all_gather(data, group=None):
|
146 |
+
"""
|
147 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
148 |
+
Args:
|
149 |
+
data: any picklable object
|
150 |
+
group: a torch process group. By default, will use a group which
|
151 |
+
contains all ranks on gloo backend.
|
152 |
+
Returns:
|
153 |
+
list[data]: list of data gathered from each rank
|
154 |
+
"""
|
155 |
+
if get_world_size() == 1:
|
156 |
+
return [data]
|
157 |
+
if group is None:
|
158 |
+
group = _get_global_gloo_group()
|
159 |
+
if dist.get_world_size(group) == 1:
|
160 |
+
return [data]
|
161 |
+
|
162 |
+
tensor = _serialize_to_tensor(data, group)
|
163 |
+
|
164 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
165 |
+
max_size = max(size_list)
|
166 |
+
|
167 |
+
# receiving Tensor from all ranks
|
168 |
+
tensor_list = [
|
169 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
170 |
+
for _ in size_list
|
171 |
+
]
|
172 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
173 |
+
|
174 |
+
data_list = []
|
175 |
+
for size, tensor in zip(size_list, tensor_list):
|
176 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
177 |
+
data_list.append(pickle.loads(buffer))
|
178 |
+
|
179 |
+
return data_list
|
180 |
+
|
181 |
+
|
182 |
+
def gather(data, dst=0, group=None):
|
183 |
+
"""
|
184 |
+
Run gather on arbitrary picklable data (not necessarily tensors).
|
185 |
+
Args:
|
186 |
+
data: any picklable object
|
187 |
+
dst (int): destination rank
|
188 |
+
group: a torch process group. By default, will use a group which
|
189 |
+
contains all ranks on gloo backend.
|
190 |
+
Returns:
|
191 |
+
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
192 |
+
an empty list.
|
193 |
+
"""
|
194 |
+
if get_world_size() == 1:
|
195 |
+
return [data]
|
196 |
+
if group is None:
|
197 |
+
group = _get_global_gloo_group()
|
198 |
+
if dist.get_world_size(group=group) == 1:
|
199 |
+
return [data]
|
200 |
+
rank = dist.get_rank(group=group)
|
201 |
+
|
202 |
+
tensor = _serialize_to_tensor(data, group)
|
203 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
204 |
+
|
205 |
+
# receiving Tensor from all ranks
|
206 |
+
if rank == dst:
|
207 |
+
max_size = max(size_list)
|
208 |
+
tensor_list = [
|
209 |
+
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
|
210 |
+
for _ in size_list
|
211 |
+
]
|
212 |
+
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
213 |
+
|
214 |
+
data_list = []
|
215 |
+
for size, tensor in zip(size_list, tensor_list):
|
216 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
217 |
+
data_list.append(pickle.loads(buffer))
|
218 |
+
return data_list
|
219 |
+
else:
|
220 |
+
dist.gather(tensor, [], dst=dst, group=group)
|
221 |
+
return []
|
222 |
+
|
223 |
+
|
224 |
+
def shared_random_seed():
|
225 |
+
"""
|
226 |
+
Returns:
|
227 |
+
int: a random number that is the same across all workers.
|
228 |
+
If workers need a shared RNG, they can use this shared seed to
|
229 |
+
create one.
|
230 |
+
All workers must call this function, otherwise it will deadlock.
|
231 |
+
"""
|
232 |
+
ints = np.random.randint(2 ** 31)
|
233 |
+
all_ints = all_gather(ints)
|
234 |
+
return all_ints[0]
|
235 |
+
|
236 |
+
|
237 |
+
# def reduce_dict(input_dict, average=True):
|
238 |
+
# """
|
239 |
+
# Reduce the values in the dictionary from all processes so that process with rank
|
240 |
+
# 0 has the reduced results.
|
241 |
+
# Args:
|
242 |
+
# input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
243 |
+
# average (bool): whether to do average or sum
|
244 |
+
# Returns:
|
245 |
+
# a dict with the same keys as input_dict, after reduction.
|
246 |
+
# """
|
247 |
+
# world_size = get_world_size()
|
248 |
+
# if world_size < 2:
|
249 |
+
# return input_dict
|
250 |
+
# with torch.no_grad():
|
251 |
+
# names = []
|
252 |
+
# values = []
|
253 |
+
# # sort the keys so that they are consistent across processes
|
254 |
+
# for k in sorted(input_dict.keys()):
|
255 |
+
# names.append(k)
|
256 |
+
# values.append(input_dict[k])
|
257 |
+
# values = torch.stack(values, dim=0)
|
258 |
+
# dist.reduce(values, dst=0)
|
259 |
+
# if dist.get_rank() == 0 and average:
|
260 |
+
# # only main process gets accumulated, so only divide by
|
261 |
+
# # world_size in this case
|
262 |
+
# values /= world_size
|
263 |
+
# reduced_dict = {k: v for k, v in zip(names, values)}
|
264 |
+
# return reduced_dict
|
265 |
+
|
266 |
+
|
267 |
+
def reduce_dict(input_dict, average=True):
|
268 |
+
"""
|
269 |
+
Reduce the values in the dictionary from all processes so that process with rank
|
270 |
+
0 has the reduced results.
|
271 |
+
Args:
|
272 |
+
input_dict (dict): inputs to be reduced. (values not necessarily tensors).
|
273 |
+
average (bool): whether to do average or sum
|
274 |
+
Returns:
|
275 |
+
a dict with the same keys as input_dict, after reduction.
|
276 |
+
"""
|
277 |
+
|
278 |
+
world_size = get_world_size()
|
279 |
+
if world_size < 2:
|
280 |
+
return input_dict
|
281 |
+
|
282 |
+
with torch.no_grad():
|
283 |
+
|
284 |
+
# Convert to CUDA Tensor for dist.reduce()
|
285 |
+
input_dict_cuda_vals = {}
|
286 |
+
for k, v in input_dict.items():
|
287 |
+
if type(v) == torch.Tensor:
|
288 |
+
input_dict_cuda_vals[k] = v.to('cuda')
|
289 |
+
else:
|
290 |
+
input_dict_cuda_vals[k] = torch.tensor(v, device='cuda')
|
291 |
+
|
292 |
+
names = []
|
293 |
+
values = []
|
294 |
+
for k, v in sorted(input_dict_cuda_vals.items()):
|
295 |
+
names.append(k)
|
296 |
+
values.append(v)
|
297 |
+
values = torch.stack(values, dim=0)
|
298 |
+
dist.reduce(values, dst=0) # reduce to gpu 0
|
299 |
+
|
300 |
+
if dist.get_rank() == 0 and average:
|
301 |
+
# only main process gets accumulated, so only divide by
|
302 |
+
# world_size in this case
|
303 |
+
values /= world_size
|
304 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
305 |
+
return reduced_dict
|
captioning/utils/div_utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from random import uniform
|
2 |
+
import numpy as np
|
3 |
+
from collections import OrderedDict, defaultdict
|
4 |
+
from itertools import tee
|
5 |
+
import time
|
6 |
+
|
7 |
+
# -----------------------------------------------
|
8 |
+
def find_ngrams(input_list, n):
|
9 |
+
return zip(*[input_list[i:] for i in range(n)])
|
10 |
+
|
11 |
+
def compute_div_n(caps,n=1):
|
12 |
+
aggr_div = []
|
13 |
+
for k in caps:
|
14 |
+
all_ngrams = set()
|
15 |
+
lenT = 0.
|
16 |
+
for c in caps[k]:
|
17 |
+
tkns = c.split()
|
18 |
+
lenT += len(tkns)
|
19 |
+
ng = find_ngrams(tkns, n)
|
20 |
+
all_ngrams.update(ng)
|
21 |
+
aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
|
22 |
+
return np.array(aggr_div).mean(), np.array(aggr_div)
|
23 |
+
|
24 |
+
def compute_global_div_n(caps,n=1):
|
25 |
+
aggr_div = []
|
26 |
+
all_ngrams = set()
|
27 |
+
lenT = 0.
|
28 |
+
for k in caps:
|
29 |
+
for c in caps[k]:
|
30 |
+
tkns = c.split()
|
31 |
+
lenT += len(tkns)
|
32 |
+
ng = find_ngrams(tkns, n)
|
33 |
+
all_ngrams.update(ng)
|
34 |
+
if n == 1:
|
35 |
+
aggr_div.append(float(len(all_ngrams)))
|
36 |
+
else:
|
37 |
+
aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
|
38 |
+
return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
|
captioning/utils/eval_multi.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import json
|
10 |
+
from json import encoder
|
11 |
+
import random
|
12 |
+
import string
|
13 |
+
import time
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from . import misc as utils
|
17 |
+
from eval_utils import getCOCO
|
18 |
+
|
19 |
+
from .div_utils import compute_div_n, compute_global_div_n
|
20 |
+
|
21 |
+
import sys
|
22 |
+
try:
|
23 |
+
sys.path.append("coco-caption")
|
24 |
+
annFile = 'coco-caption/annotations/captions_val2014.json'
|
25 |
+
from pycocotools.coco import COCO
|
26 |
+
from pycocoevalcap.eval import COCOEvalCap
|
27 |
+
from pycocoevalcap.eval_spice import COCOEvalCapSpice
|
28 |
+
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
29 |
+
from pycocoevalcap.bleu.bleu import Bleu
|
30 |
+
sys.path.append("cider")
|
31 |
+
from pyciderevalcap.cider.cider import Cider
|
32 |
+
except:
|
33 |
+
print('Warning: requirements for eval_multi not satisfied')
|
34 |
+
|
35 |
+
|
36 |
+
def eval_allspice(dataset, preds_n, model_id, split):
|
37 |
+
coco = getCOCO(dataset)
|
38 |
+
valids = coco.getImgIds()
|
39 |
+
|
40 |
+
capsById = {}
|
41 |
+
for d in preds_n:
|
42 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
43 |
+
|
44 |
+
# filter results to only those in MSCOCO validation set (will be about a third)
|
45 |
+
preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
|
46 |
+
print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
|
47 |
+
cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
48 |
+
json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
49 |
+
|
50 |
+
# Eval AllSPICE
|
51 |
+
cocoRes_n = coco.loadRes(cache_path_n)
|
52 |
+
cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
|
53 |
+
cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
|
54 |
+
cocoEvalAllSPICE.evaluate()
|
55 |
+
|
56 |
+
out = {}
|
57 |
+
for metric, score in cocoEvalAllSPICE.eval.items():
|
58 |
+
out['All'+metric] = score
|
59 |
+
|
60 |
+
imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
|
61 |
+
# collect SPICE_sub_score
|
62 |
+
for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
|
63 |
+
if k != 'All':
|
64 |
+
out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
|
65 |
+
out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
|
66 |
+
for p in preds_filt_n:
|
67 |
+
image_id, caption = p['image_id'], p['caption']
|
68 |
+
imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
|
69 |
+
return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
|
70 |
+
|
71 |
+
def eval_oracle(dataset, preds_n, model_id, split):
|
72 |
+
cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
73 |
+
|
74 |
+
coco = getCOCO(dataset)
|
75 |
+
valids = coco.getImgIds()
|
76 |
+
|
77 |
+
capsById = {}
|
78 |
+
for d in preds_n:
|
79 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
80 |
+
|
81 |
+
sample_n = capsById[list(capsById.keys())[0]]
|
82 |
+
for i in range(len(capsById[list(capsById.keys())[0]])):
|
83 |
+
preds = [_[i] for _ in capsById.values()]
|
84 |
+
|
85 |
+
json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
86 |
+
|
87 |
+
cocoRes = coco.loadRes(cache_path)
|
88 |
+
cocoEval = COCOEvalCap(coco, cocoRes)
|
89 |
+
cocoEval.params['image_id'] = cocoRes.getImgIds()
|
90 |
+
cocoEval.evaluate()
|
91 |
+
|
92 |
+
imgToEval = cocoEval.imgToEval
|
93 |
+
for img_id in capsById.keys():
|
94 |
+
tmp = imgToEval[img_id]
|
95 |
+
for k in tmp['SPICE'].keys():
|
96 |
+
if k != 'All':
|
97 |
+
tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
|
98 |
+
if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
|
99 |
+
tmp['SPICE_'+k] = -100
|
100 |
+
tmp['SPICE'] = tmp['SPICE']['All']['f']
|
101 |
+
if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
|
102 |
+
capsById[img_id][i]['scores'] = imgToEval[img_id]
|
103 |
+
|
104 |
+
out = {'overall': {}, 'ImgToEval': {}}
|
105 |
+
for img_id in capsById.keys():
|
106 |
+
out['ImgToEval'][img_id] = {}
|
107 |
+
for metric in capsById[img_id][0]['scores'].keys():
|
108 |
+
if metric == 'image_id': continue
|
109 |
+
out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
|
110 |
+
out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
|
111 |
+
out['ImgToEval'][img_id]['captions'] = capsById[img_id]
|
112 |
+
for metric in list(out['ImgToEval'].values())[0].keys():
|
113 |
+
if metric == 'captions':
|
114 |
+
continue
|
115 |
+
tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
|
116 |
+
tmp = tmp[tmp!=-100]
|
117 |
+
out['overall'][metric] = tmp.mean()
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def eval_div_stats(dataset, preds_n, model_id, split):
|
122 |
+
tokenizer = PTBTokenizer()
|
123 |
+
|
124 |
+
capsById = {}
|
125 |
+
for i, d in enumerate(preds_n):
|
126 |
+
d['id'] = i
|
127 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
128 |
+
|
129 |
+
n_caps_perimg = len(capsById[list(capsById.keys())[0]])
|
130 |
+
print(n_caps_perimg)
|
131 |
+
_capsById = capsById # save the untokenized version
|
132 |
+
capsById = tokenizer.tokenize(capsById)
|
133 |
+
|
134 |
+
div_1, adiv_1 = compute_div_n(capsById,1)
|
135 |
+
div_2, adiv_2 = compute_div_n(capsById,2)
|
136 |
+
|
137 |
+
globdiv_1, _= compute_global_div_n(capsById,1)
|
138 |
+
|
139 |
+
print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
|
140 |
+
|
141 |
+
# compute mbleu
|
142 |
+
scorer = Bleu(4)
|
143 |
+
all_scrs = []
|
144 |
+
scrperimg = np.zeros((n_caps_perimg, len(capsById)))
|
145 |
+
|
146 |
+
for i in range(n_caps_perimg):
|
147 |
+
tempRefsById = {}
|
148 |
+
candsById = {}
|
149 |
+
for k in capsById:
|
150 |
+
tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
|
151 |
+
candsById[k] = [capsById[k][i]]
|
152 |
+
|
153 |
+
score, scores = scorer.compute_score(tempRefsById, candsById)
|
154 |
+
all_scrs.append(score)
|
155 |
+
scrperimg[i,:] = scores[1]
|
156 |
+
|
157 |
+
all_scrs = np.array(all_scrs)
|
158 |
+
|
159 |
+
out = {}
|
160 |
+
out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
|
161 |
+
for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
|
162 |
+
out['overall'].update({'mBLeu_%d'%(k+1): score})
|
163 |
+
imgToEval = {}
|
164 |
+
for i,imgid in enumerate(capsById.keys()):
|
165 |
+
imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
|
166 |
+
imgToEval[imgid]['individuals'] = []
|
167 |
+
for j, d in enumerate(_capsById[imgid]):
|
168 |
+
imgToEval[imgid]['individuals'].append(preds_n[d['id']])
|
169 |
+
imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
|
170 |
+
out['ImgToEval'] = imgToEval
|
171 |
+
|
172 |
+
print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
|
173 |
+
print(all_scrs.mean(axis=0))
|
174 |
+
|
175 |
+
return out
|
176 |
+
|
177 |
+
def eval_self_cider(dataset, preds_n, model_id, split):
|
178 |
+
cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
|
179 |
+
|
180 |
+
coco = getCOCO(dataset)
|
181 |
+
valids = coco.getImgIds()
|
182 |
+
|
183 |
+
# Get Cider_scorer
|
184 |
+
Cider_scorer = Cider(df='corpus')
|
185 |
+
|
186 |
+
tokenizer = PTBTokenizer()
|
187 |
+
gts = {}
|
188 |
+
for imgId in valids:
|
189 |
+
gts[imgId] = coco.imgToAnns[imgId]
|
190 |
+
gts = tokenizer.tokenize(gts)
|
191 |
+
|
192 |
+
for imgId in valids:
|
193 |
+
Cider_scorer.cider_scorer += (None, gts[imgId])
|
194 |
+
Cider_scorer.cider_scorer.compute_doc_freq()
|
195 |
+
Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
|
196 |
+
|
197 |
+
# Prepare captions
|
198 |
+
capsById = {}
|
199 |
+
for d in preds_n:
|
200 |
+
capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
|
201 |
+
|
202 |
+
capsById = tokenizer.tokenize(capsById)
|
203 |
+
imgIds = list(capsById.keys())
|
204 |
+
scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
|
205 |
+
|
206 |
+
def get_div(eigvals):
|
207 |
+
eigvals = np.clip(eigvals, 0, None)
|
208 |
+
return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
|
209 |
+
sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
|
210 |
+
score = np.mean(np.array(sc_scores))
|
211 |
+
|
212 |
+
imgToEval = {}
|
213 |
+
for i, image_id in enumerate(imgIds):
|
214 |
+
imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
|
215 |
+
return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
|
216 |
+
|
217 |
+
|
218 |
+
return score
|
captioning/utils/eval_utils.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import json
|
11 |
+
from json import encoder
|
12 |
+
import random
|
13 |
+
import string
|
14 |
+
import time
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
from . import misc as utils
|
18 |
+
|
19 |
+
# load coco-caption if available
|
20 |
+
try:
|
21 |
+
sys.path.append("coco-caption")
|
22 |
+
from pycocotools.coco import COCO
|
23 |
+
from pycocoevalcap.eval import COCOEvalCap
|
24 |
+
except:
|
25 |
+
print('Warning: coco-caption not available')
|
26 |
+
|
27 |
+
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
28 |
+
bad_endings += ['the']
|
29 |
+
|
30 |
+
|
31 |
+
def count_bad(sen):
|
32 |
+
sen = sen.split(' ')
|
33 |
+
if sen[-1] in bad_endings:
|
34 |
+
return 1
|
35 |
+
else:
|
36 |
+
return 0
|
37 |
+
|
38 |
+
|
39 |
+
def getCOCO(dataset):
|
40 |
+
if 'coco' in dataset:
|
41 |
+
annFile = 'coco-caption/annotations/captions_val2014.json'
|
42 |
+
elif 'flickr30k' in dataset or 'f30k' in dataset:
|
43 |
+
annFile = 'data/f30k_captions4eval.json'
|
44 |
+
return COCO(annFile)
|
45 |
+
|
46 |
+
|
47 |
+
def language_eval(dataset, preds, preds_n, eval_kwargs, split):
|
48 |
+
model_id = eval_kwargs['id']
|
49 |
+
eval_oracle = eval_kwargs.get('eval_oracle', 0)
|
50 |
+
|
51 |
+
# create output dictionary
|
52 |
+
out = {}
|
53 |
+
|
54 |
+
if len(preds_n) > 0:
|
55 |
+
# vocab size and novel sentences
|
56 |
+
if 'coco' in dataset:
|
57 |
+
dataset_file = 'data/dataset_coco.json'
|
58 |
+
elif 'flickr30k' in dataset or 'f30k' in dataset:
|
59 |
+
dataset_file = 'data/dataset_flickr30k.json'
|
60 |
+
training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
|
61 |
+
generated_sentences = set([_['caption'] for _ in preds_n])
|
62 |
+
novels = generated_sentences - training_sentences
|
63 |
+
out['novel_sentences'] = float(len(novels)) / len(preds_n)
|
64 |
+
tmp = [_.split() for _ in generated_sentences]
|
65 |
+
words = []
|
66 |
+
for _ in tmp:
|
67 |
+
words += _
|
68 |
+
out['vocab_size'] = len(set(words))
|
69 |
+
|
70 |
+
# encoder.FLOAT_REPR = lambda o: format(o, '.3f')
|
71 |
+
|
72 |
+
cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')
|
73 |
+
|
74 |
+
coco = getCOCO(dataset)
|
75 |
+
valids = coco.getImgIds()
|
76 |
+
|
77 |
+
# filter results to only those in MSCOCO validation set
|
78 |
+
preds_filt = [p for p in preds if p['image_id'] in valids]
|
79 |
+
mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
|
80 |
+
mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
|
81 |
+
print('using %d/%d predictions' % (len(preds_filt), len(preds)))
|
82 |
+
json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
|
83 |
+
|
84 |
+
cocoRes = coco.loadRes(cache_path)
|
85 |
+
cocoEval = COCOEvalCap(coco, cocoRes)
|
86 |
+
cocoEval.params['image_id'] = cocoRes.getImgIds()
|
87 |
+
cocoEval.evaluate()
|
88 |
+
|
89 |
+
for metric, score in cocoEval.eval.items():
|
90 |
+
out[metric] = score
|
91 |
+
# Add mean perplexity
|
92 |
+
out['perplexity'] = mean_perplexity
|
93 |
+
out['entropy'] = mean_entropy
|
94 |
+
|
95 |
+
imgToEval = cocoEval.imgToEval
|
96 |
+
for k in list(imgToEval.values())[0]['SPICE'].keys():
|
97 |
+
if k != 'All':
|
98 |
+
out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
|
99 |
+
out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
|
100 |
+
for p in preds_filt:
|
101 |
+
image_id, caption = p['image_id'], p['caption']
|
102 |
+
imgToEval[image_id]['caption'] = caption
|
103 |
+
|
104 |
+
if len(preds_n) > 0:
|
105 |
+
from . import eval_multi
|
106 |
+
cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
|
107 |
+
allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
|
108 |
+
out.update(allspice['overall'])
|
109 |
+
div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
|
110 |
+
out.update(div_stats['overall'])
|
111 |
+
if eval_oracle:
|
112 |
+
oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
|
113 |
+
out.update(oracle['overall'])
|
114 |
+
else:
|
115 |
+
oracle = None
|
116 |
+
self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
|
117 |
+
out.update(self_cider['overall'])
|
118 |
+
with open(cache_path_n, 'w') as outfile:
|
119 |
+
json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
|
120 |
+
|
121 |
+
out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
|
122 |
+
outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
|
123 |
+
with open(outfile_path, 'w') as outfile:
|
124 |
+
json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
|
125 |
+
|
126 |
+
return out
|
127 |
+
|
128 |
+
def eval_split(model, crit, loader, eval_kwargs={}):
|
129 |
+
verbose = eval_kwargs.get('verbose', True)
|
130 |
+
verbose_beam = eval_kwargs.get('verbose_beam', 0)
|
131 |
+
verbose_loss = eval_kwargs.get('verbose_loss', 1)
|
132 |
+
num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
|
133 |
+
split = eval_kwargs.get('split', 'val')
|
134 |
+
lang_eval = eval_kwargs.get('language_eval', 0)
|
135 |
+
dataset = eval_kwargs.get('dataset', 'coco')
|
136 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
137 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
138 |
+
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
|
139 |
+
os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
|
140 |
+
device = eval_kwargs.get('device', 'cuda')
|
141 |
+
|
142 |
+
# Make sure in the evaluation mode
|
143 |
+
model.eval()
|
144 |
+
|
145 |
+
loader.reset_iterator(split)
|
146 |
+
|
147 |
+
n = 0
|
148 |
+
loss = 0
|
149 |
+
loss_sum = 0
|
150 |
+
loss_evals = 1e-8
|
151 |
+
predictions = []
|
152 |
+
n_predictions = [] # when sample_n > 1
|
153 |
+
while True:
|
154 |
+
data = loader.get_batch(split)
|
155 |
+
n = n + len(data['infos'])
|
156 |
+
|
157 |
+
tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
|
158 |
+
tmp = [_.to(device) if _ is not None else _ for _ in tmp]
|
159 |
+
fc_feats, att_feats, labels, masks, att_masks = tmp
|
160 |
+
if labels is not None and verbose_loss:
|
161 |
+
# forward the model to get loss
|
162 |
+
with torch.no_grad():
|
163 |
+
loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
|
164 |
+
loss_sum = loss_sum + loss
|
165 |
+
loss_evals = loss_evals + 1
|
166 |
+
|
167 |
+
# forward the model to also get generated samples for each image
|
168 |
+
with torch.no_grad():
|
169 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
170 |
+
tmp_eval_kwargs.update({'sample_n': 1})
|
171 |
+
seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
172 |
+
seq = seq.data
|
173 |
+
entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
174 |
+
perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
175 |
+
|
176 |
+
# Print beam search
|
177 |
+
if beam_size > 1 and verbose_beam:
|
178 |
+
for i in range(fc_feats.shape[0]):
|
179 |
+
print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
|
180 |
+
print('--' * 10)
|
181 |
+
sents = utils.decode_sequence(model.vocab, seq)
|
182 |
+
|
183 |
+
for k, sent in enumerate(sents):
|
184 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
|
185 |
+
if eval_kwargs.get('dump_path', 0) == 1:
|
186 |
+
entry['file_name'] = data['infos'][k]['file_path']
|
187 |
+
predictions.append(entry)
|
188 |
+
if eval_kwargs.get('dump_images', 0) == 1:
|
189 |
+
# dump the raw image to vis/ folder
|
190 |
+
cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
|
191 |
+
print(cmd)
|
192 |
+
os.system(cmd)
|
193 |
+
|
194 |
+
if verbose:
|
195 |
+
print('image %s: %s' %(entry['image_id'], entry['caption']))
|
196 |
+
|
197 |
+
if sample_n > 1:
|
198 |
+
eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
|
199 |
+
|
200 |
+
# ix0 = data['bounds']['it_pos_now']
|
201 |
+
ix1 = data['bounds']['it_max']
|
202 |
+
if num_images != -1:
|
203 |
+
ix1 = min(ix1, num_images)
|
204 |
+
else:
|
205 |
+
num_images = ix1
|
206 |
+
for i in range(n - ix1):
|
207 |
+
predictions.pop()
|
208 |
+
|
209 |
+
if verbose:
|
210 |
+
print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))
|
211 |
+
|
212 |
+
if num_images >= 0 and n >= num_images:
|
213 |
+
break
|
214 |
+
|
215 |
+
lang_stats = None
|
216 |
+
if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
|
217 |
+
n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
|
218 |
+
if not os.path.isdir('eval_results'):
|
219 |
+
os.mkdir('eval_results')
|
220 |
+
torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
|
221 |
+
if lang_eval == 1:
|
222 |
+
lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
|
223 |
+
|
224 |
+
# Switch back to training mode
|
225 |
+
model.train()
|
226 |
+
return loss_sum/loss_evals, predictions, lang_stats
|
227 |
+
|
228 |
+
|
229 |
+
# Only run when sample_n > 0
|
230 |
+
def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
|
231 |
+
verbose = eval_kwargs.get('verbose', True)
|
232 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
233 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
234 |
+
sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
|
235 |
+
|
236 |
+
fc_feats, att_feats, att_masks, data = input_data
|
237 |
+
|
238 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
239 |
+
if sample_n_method == 'bs':
|
240 |
+
# case 1 sample_n == beam size
|
241 |
+
tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
|
242 |
+
with torch.no_grad():
|
243 |
+
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
244 |
+
for k in range(fc_feats.shape[0]):
|
245 |
+
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
|
246 |
+
for sent in _sents:
|
247 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
248 |
+
n_predictions.append(entry)
|
249 |
+
# case 2 sample / gumbel / topk sampling/ nucleus sampling
|
250 |
+
elif sample_n_method == 'sample' or \
|
251 |
+
sample_n_method == 'gumbel' or \
|
252 |
+
sample_n_method.startswith('top'):
|
253 |
+
tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
|
254 |
+
with torch.no_grad():
|
255 |
+
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
256 |
+
_sents = utils.decode_sequence(model.vocab, _seq)
|
257 |
+
_perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
|
258 |
+
for k, sent in enumerate(_sents):
|
259 |
+
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
|
260 |
+
n_predictions.append(entry)
|
261 |
+
elif sample_n_method == 'dbs':
|
262 |
+
# Use diverse beam search
|
263 |
+
tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
|
264 |
+
with torch.no_grad():
|
265 |
+
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
266 |
+
for k in range(loader.batch_size):
|
267 |
+
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
|
268 |
+
for sent in _sents:
|
269 |
+
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
270 |
+
n_predictions.append(entry)
|
271 |
+
else:
|
272 |
+
tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
|
273 |
+
with torch.no_grad():
|
274 |
+
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
275 |
+
_sents = utils.decode_sequence(model.vocab, _seq)
|
276 |
+
for k, sent in enumerate(_sents):
|
277 |
+
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
|
278 |
+
n_predictions.append(entry)
|
279 |
+
if verbose:
|
280 |
+
for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
|
281 |
+
print('image %s: %s' %(entry['image_id'], entry['caption']))
|
captioning/utils/misc.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import collections
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import numpy as np
|
9 |
+
import torch.optim as optim
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
import six
|
15 |
+
from six.moves import cPickle
|
16 |
+
|
17 |
+
bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that']
|
18 |
+
bad_endings += ['the']
|
19 |
+
|
20 |
+
|
21 |
+
def pickle_load(f):
|
22 |
+
""" Load a pickle.
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
f: file-like object
|
26 |
+
"""
|
27 |
+
if six.PY3:
|
28 |
+
return cPickle.load(f, encoding='latin-1')
|
29 |
+
else:
|
30 |
+
return cPickle.load(f)
|
31 |
+
|
32 |
+
|
33 |
+
def pickle_dump(obj, f):
|
34 |
+
""" Dump a pickle.
|
35 |
+
Parameters
|
36 |
+
----------
|
37 |
+
obj: pickled object
|
38 |
+
f: file-like object
|
39 |
+
"""
|
40 |
+
if six.PY3:
|
41 |
+
return cPickle.dump(obj, f, protocol=2)
|
42 |
+
else:
|
43 |
+
return cPickle.dump(obj, f)
|
44 |
+
|
45 |
+
|
46 |
+
# modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
|
47 |
+
def serialize_to_tensor(data):
|
48 |
+
device = torch.device("cpu")
|
49 |
+
|
50 |
+
buffer = cPickle.dumps(data)
|
51 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
52 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
53 |
+
return tensor
|
54 |
+
|
55 |
+
|
56 |
+
def deserialize(tensor):
|
57 |
+
buffer = tensor.cpu().numpy().tobytes()
|
58 |
+
return cPickle.loads(buffer)
|
59 |
+
|
60 |
+
|
61 |
+
# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
|
62 |
+
def decode_sequence(ix_to_word, seq):
|
63 |
+
# N, D = seq.size()
|
64 |
+
N, D = seq.shape
|
65 |
+
out = []
|
66 |
+
for i in range(N):
|
67 |
+
txt = ''
|
68 |
+
for j in range(D):
|
69 |
+
ix = seq[i,j]
|
70 |
+
if ix > 0 :
|
71 |
+
if j >= 1:
|
72 |
+
txt = txt + ' '
|
73 |
+
txt = txt + ix_to_word[str(ix.item())]
|
74 |
+
else:
|
75 |
+
break
|
76 |
+
if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
|
77 |
+
flag = 0
|
78 |
+
words = txt.split(' ')
|
79 |
+
for j in range(len(words)):
|
80 |
+
if words[-j-1] not in bad_endings:
|
81 |
+
flag = -j
|
82 |
+
break
|
83 |
+
txt = ' '.join(words[0:len(words)+flag])
|
84 |
+
out.append(txt.replace('@@ ', ''))
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
|
89 |
+
if len(append) > 0:
|
90 |
+
append = '-' + append
|
91 |
+
# if checkpoint_path doesn't exist
|
92 |
+
if not os.path.isdir(opt.checkpoint_path):
|
93 |
+
os.makedirs(opt.checkpoint_path)
|
94 |
+
checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
|
95 |
+
torch.save(model.state_dict(), checkpoint_path)
|
96 |
+
print("model saved to {}".format(checkpoint_path))
|
97 |
+
optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
|
98 |
+
torch.save(optimizer.state_dict(), optimizer_path)
|
99 |
+
with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
|
100 |
+
pickle_dump(infos, f)
|
101 |
+
if histories:
|
102 |
+
with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
|
103 |
+
pickle_dump(histories, f)
|
104 |
+
|
105 |
+
|
106 |
+
def set_lr(optimizer, lr):
|
107 |
+
for group in optimizer.param_groups:
|
108 |
+
group['lr'] = lr
|
109 |
+
|
110 |
+
def get_lr(optimizer):
|
111 |
+
for group in optimizer.param_groups:
|
112 |
+
return group['lr']
|
113 |
+
|
114 |
+
|
115 |
+
def build_optimizer(params, opt):
|
116 |
+
if opt.optim == 'rmsprop':
|
117 |
+
return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
|
118 |
+
elif opt.optim == 'adagrad':
|
119 |
+
return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
|
120 |
+
elif opt.optim == 'sgd':
|
121 |
+
return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
|
122 |
+
elif opt.optim == 'sgdm':
|
123 |
+
return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
|
124 |
+
elif opt.optim == 'sgdmom':
|
125 |
+
return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
|
126 |
+
elif opt.optim == 'adam':
|
127 |
+
return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
|
128 |
+
elif opt.optim == 'adamw':
|
129 |
+
return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
|
130 |
+
else:
|
131 |
+
raise Exception("bad option opt.optim: {}".format(opt.optim))
|
132 |
+
|
133 |
+
|
134 |
+
def penalty_builder(penalty_config):
|
135 |
+
if penalty_config == '':
|
136 |
+
return lambda x,y: y
|
137 |
+
pen_type, alpha = penalty_config.split('_')
|
138 |
+
alpha = float(alpha)
|
139 |
+
if pen_type == 'wu':
|
140 |
+
return lambda x,y: length_wu(x,y,alpha)
|
141 |
+
if pen_type == 'avg':
|
142 |
+
return lambda x,y: length_average(x,y,alpha)
|
143 |
+
|
144 |
+
def length_wu(length, logprobs, alpha=0.):
|
145 |
+
"""
|
146 |
+
NMT length re-ranking score from
|
147 |
+
"Google's Neural Machine Translation System" :cite:`wu2016google`.
|
148 |
+
"""
|
149 |
+
|
150 |
+
modifier = (((5 + length) ** alpha) /
|
151 |
+
((5 + 1) ** alpha))
|
152 |
+
return (logprobs / modifier)
|
153 |
+
|
154 |
+
def length_average(length, logprobs, alpha=0.):
|
155 |
+
"""
|
156 |
+
Returns the average probability of tokens in a sequence.
|
157 |
+
"""
|
158 |
+
return logprobs / length
|
159 |
+
|
160 |
+
|
161 |
+
class NoamOpt(object):
|
162 |
+
"Optim wrapper that implements rate."
|
163 |
+
def __init__(self, model_size, factor, warmup, optimizer):
|
164 |
+
self.optimizer = optimizer
|
165 |
+
self._step = 0
|
166 |
+
self.warmup = warmup
|
167 |
+
self.factor = factor
|
168 |
+
self.model_size = model_size
|
169 |
+
self._rate = 0
|
170 |
+
|
171 |
+
def step(self):
|
172 |
+
"Update parameters and rate"
|
173 |
+
self._step += 1
|
174 |
+
rate = self.rate()
|
175 |
+
for p in self.optimizer.param_groups:
|
176 |
+
p['lr'] = rate
|
177 |
+
self._rate = rate
|
178 |
+
self.optimizer.step()
|
179 |
+
|
180 |
+
def rate(self, step = None):
|
181 |
+
"Implement `lrate` above"
|
182 |
+
if step is None:
|
183 |
+
step = self._step
|
184 |
+
return self.factor * \
|
185 |
+
(self.model_size ** (-0.5) *
|
186 |
+
min(step ** (-0.5), step * self.warmup ** (-1.5)))
|
187 |
+
|
188 |
+
def __getattr__(self, name):
|
189 |
+
return getattr(self.optimizer, name)
|
190 |
+
|
191 |
+
def state_dict(self):
|
192 |
+
state_dict = self.optimizer.state_dict()
|
193 |
+
state_dict['_step'] = self._step
|
194 |
+
return state_dict
|
195 |
+
|
196 |
+
def load_state_dict(self, state_dict):
|
197 |
+
if '_step' in state_dict:
|
198 |
+
self._step = state_dict['_step']
|
199 |
+
del state_dict['_step']
|
200 |
+
self.optimizer.load_state_dict(state_dict)
|
201 |
+
|
202 |
+
class ReduceLROnPlateau(object):
|
203 |
+
"Optim wrapper that implements rate."
|
204 |
+
def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
|
205 |
+
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
|
206 |
+
self.optimizer = optimizer
|
207 |
+
self.current_lr = get_lr(optimizer)
|
208 |
+
|
209 |
+
def step(self):
|
210 |
+
"Update parameters and rate"
|
211 |
+
self.optimizer.step()
|
212 |
+
|
213 |
+
def scheduler_step(self, val):
|
214 |
+
self.scheduler.step(val)
|
215 |
+
self.current_lr = get_lr(self.optimizer)
|
216 |
+
|
217 |
+
def state_dict(self):
|
218 |
+
return {'current_lr':self.current_lr,
|
219 |
+
'scheduler_state_dict': self.scheduler.state_dict(),
|
220 |
+
'optimizer_state_dict': self.optimizer.state_dict()}
|
221 |
+
|
222 |
+
def load_state_dict(self, state_dict):
|
223 |
+
if 'current_lr' not in state_dict:
|
224 |
+
# it's normal optimizer
|
225 |
+
self.optimizer.load_state_dict(state_dict)
|
226 |
+
set_lr(self.optimizer, self.current_lr) # use the lr fromt the option
|
227 |
+
else:
|
228 |
+
# it's a schduler
|
229 |
+
self.current_lr = state_dict['current_lr']
|
230 |
+
self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
|
231 |
+
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
|
232 |
+
# current_lr is actually useless in this case
|
233 |
+
|
234 |
+
def rate(self, step = None):
|
235 |
+
"Implement `lrate` above"
|
236 |
+
if step is None:
|
237 |
+
step = self._step
|
238 |
+
return self.factor * \
|
239 |
+
(self.model_size ** (-0.5) *
|
240 |
+
min(step ** (-0.5), step * self.warmup ** (-1.5)))
|
241 |
+
|
242 |
+
def __getattr__(self, name):
|
243 |
+
return getattr(self.optimizer, name)
|
244 |
+
|
245 |
+
def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
|
246 |
+
# return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
|
247 |
+
# torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
248 |
+
optim_func = dict(adam=torch.optim.Adam,
|
249 |
+
adamw=torch.optim.AdamW)[optim_func]
|
250 |
+
return NoamOpt(model.d_model, factor, warmup,
|
251 |
+
optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
|
captioning/utils/opts.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
def if_use_feat(caption_model):
|
6 |
+
# Decide if load attention feature according to caption model
|
7 |
+
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
|
8 |
+
use_att, use_fc = False, True
|
9 |
+
elif caption_model == 'language_model':
|
10 |
+
use_att, use_fc = False, False
|
11 |
+
elif caption_model in ['updown', 'topdown']:
|
12 |
+
use_fc, use_att = True, True
|
13 |
+
else:
|
14 |
+
use_att, use_fc = True, False
|
15 |
+
return use_fc, use_att
|
16 |
+
|
17 |
+
import pprint
|
18 |
+
class Config(object):
|
19 |
+
def __init__(self, **kwargs):
|
20 |
+
"""Configuration Class: set kwargs as class attributes with setattr"""
|
21 |
+
for k, v in kwargs.items():
|
22 |
+
setattr(self, k, v)
|
23 |
+
|
24 |
+
@property
|
25 |
+
def config_str(self):
|
26 |
+
return pprint.pformat(self.__dict__)
|
27 |
+
|
28 |
+
def __repr__(self):
|
29 |
+
"""Pretty-print configurations in alphabetical order"""
|
30 |
+
config_str = 'Configurations\n'
|
31 |
+
config_str += self.config_str
|
32 |
+
return config_str
|
33 |
+
|
34 |
+
|
35 |
+
def parse_opt(parse=True, **optional_kwargs):
|
36 |
+
parser = argparse.ArgumentParser()
|
37 |
+
# Data input settings
|
38 |
+
parser.add_argument('--input_json', type=str, default='data/coco.json',
|
39 |
+
help='path to the json file containing additional info and vocab')
|
40 |
+
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
|
41 |
+
help='path to the directory containing the preprocessed fc feats')
|
42 |
+
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
|
43 |
+
help='path to the directory containing the preprocessed att feats')
|
44 |
+
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
|
45 |
+
help='path to the directory containing the boxes of att feats')
|
46 |
+
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
|
47 |
+
help='path to the h5file containing the preprocessed dataset')
|
48 |
+
parser.add_argument('--data_in_memory', action='store_true',
|
49 |
+
help='True if we want to save the features in memory')
|
50 |
+
parser.add_argument('--start_from', type=str, default=None,
|
51 |
+
help="""continue training from saved model at this path. Path must contain files saved by previous training process:
|
52 |
+
'infos.pkl' : configuration;
|
53 |
+
'model.pth' : weights
|
54 |
+
""")
|
55 |
+
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
|
56 |
+
help='Cached token file for calculating cider score during self critical training.')
|
57 |
+
|
58 |
+
# Model settings
|
59 |
+
parser.add_argument('--caption_model', type=str, default="show_tell",
|
60 |
+
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
|
61 |
+
parser.add_argument('--rnn_size', type=int, default=512,
|
62 |
+
help='size of the rnn in number of hidden nodes in each layer')
|
63 |
+
parser.add_argument('--num_layers', type=int, default=1,
|
64 |
+
help='number of layers in the RNN')
|
65 |
+
parser.add_argument('--rnn_type', type=str, default='lstm',
|
66 |
+
help='rnn, gru, or lstm')
|
67 |
+
parser.add_argument('--input_encoding_size', type=int, default=512,
|
68 |
+
help='the encoding size of each token in the vocabulary, and the image.')
|
69 |
+
parser.add_argument('--att_hid_size', type=int, default=512,
|
70 |
+
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
|
71 |
+
parser.add_argument('--fc_feat_size', type=int, default=2048,
|
72 |
+
help='2048 for resnet, 4096 for vgg')
|
73 |
+
parser.add_argument('--att_feat_size', type=int, default=2048,
|
74 |
+
help='2048 for resnet, 512 for vgg')
|
75 |
+
parser.add_argument('--logit_layers', type=int, default=1,
|
76 |
+
help='number of layers in the RNN')
|
77 |
+
|
78 |
+
|
79 |
+
parser.add_argument('--use_bn', type=int, default=0,
|
80 |
+
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
|
81 |
+
|
82 |
+
# feature manipulation
|
83 |
+
parser.add_argument('--norm_att_feat', type=int, default=0,
|
84 |
+
help='If normalize attention features')
|
85 |
+
parser.add_argument('--use_box', type=int, default=0,
|
86 |
+
help='If use box features')
|
87 |
+
parser.add_argument('--norm_box_feat', type=int, default=0,
|
88 |
+
help='If use box, do we normalize box feature')
|
89 |
+
|
90 |
+
# Optimization: General
|
91 |
+
parser.add_argument('--max_epochs', type=int, default=-1,
|
92 |
+
help='number of epochs')
|
93 |
+
parser.add_argument('--batch_size', type=int, default=16,
|
94 |
+
help='minibatch size')
|
95 |
+
parser.add_argument('--grad_clip_mode', type=str, default='value',
|
96 |
+
help='value or norm')
|
97 |
+
parser.add_argument('--grad_clip_value', type=float, default=0.1,
|
98 |
+
help='clip gradients at this value/max_norm, 0 means no clipping')
|
99 |
+
parser.add_argument('--drop_prob_lm', type=float, default=0.5,
|
100 |
+
help='strength of dropout in the Language Model RNN')
|
101 |
+
parser.add_argument('--self_critical_after', type=int, default=-1,
|
102 |
+
help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
|
103 |
+
parser.add_argument('--seq_per_img', type=int, default=5,
|
104 |
+
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
|
105 |
+
|
106 |
+
parser.add_argument('--verbose', type=int, default=0)
|
107 |
+
|
108 |
+
# Sample related
|
109 |
+
add_eval_sample_opts(parser)
|
110 |
+
|
111 |
+
#Optimization: for the Language Model
|
112 |
+
parser.add_argument('--optim', type=str, default='adam',
|
113 |
+
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
|
114 |
+
parser.add_argument('--learning_rate', type=float, default=4e-4,
|
115 |
+
help='learning rate')
|
116 |
+
parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
|
117 |
+
help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
|
118 |
+
parser.add_argument('--learning_rate_decay_every', type=int, default=3,
|
119 |
+
help='every how many iterations thereafter to drop LR?(in epoch)')
|
120 |
+
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
|
121 |
+
help='every how many iterations thereafter to drop LR?(in epoch)')
|
122 |
+
parser.add_argument('--optim_alpha', type=float, default=0.9,
|
123 |
+
help='alpha for adam')
|
124 |
+
parser.add_argument('--optim_beta', type=float, default=0.999,
|
125 |
+
help='beta used for adam')
|
126 |
+
parser.add_argument('--optim_epsilon', type=float, default=1e-8,
|
127 |
+
help='epsilon that goes into denominator for smoothing')
|
128 |
+
parser.add_argument('--weight_decay', type=float, default=0,
|
129 |
+
help='weight_decay')
|
130 |
+
# Transformer
|
131 |
+
parser.add_argument('--label_smoothing', type=float, default=0,
|
132 |
+
help='')
|
133 |
+
parser.add_argument('--noamopt', action='store_true',
|
134 |
+
help='')
|
135 |
+
parser.add_argument('--noamopt_warmup', type=int, default=2000,
|
136 |
+
help='')
|
137 |
+
parser.add_argument('--noamopt_factor', type=float, default=1,
|
138 |
+
help='')
|
139 |
+
parser.add_argument('--reduce_on_plateau', action='store_true',
|
140 |
+
help='')
|
141 |
+
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
|
142 |
+
help='')
|
143 |
+
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
|
144 |
+
help='')
|
145 |
+
parser.add_argument('--cached_transformer', action='store_true',
|
146 |
+
help='')
|
147 |
+
|
148 |
+
|
149 |
+
parser.add_argument('--use_warmup', action='store_true',
|
150 |
+
help='warm up the learing rate?')
|
151 |
+
|
152 |
+
parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
|
153 |
+
help='at what iteration to start decay gt probability')
|
154 |
+
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
|
155 |
+
help='every how many iterations thereafter to gt probability')
|
156 |
+
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
|
157 |
+
help='How much to update the prob')
|
158 |
+
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
|
159 |
+
help='Maximum scheduled sampling prob.')
|
160 |
+
|
161 |
+
|
162 |
+
# Evaluation/Checkpointing
|
163 |
+
parser.add_argument('--val_images_use', type=int, default=3200,
|
164 |
+
help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
|
165 |
+
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
|
166 |
+
help='how often to save a model checkpoint (in iterations)?')
|
167 |
+
parser.add_argument('--save_every_epoch', action='store_true',
|
168 |
+
help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
|
169 |
+
parser.add_argument('--save_history_ckpt', type=int, default=0,
|
170 |
+
help='If save checkpoints at every save point')
|
171 |
+
parser.add_argument('--checkpoint_path', type=str, default=None,
|
172 |
+
help='directory to store checkpointed models')
|
173 |
+
parser.add_argument('--language_eval', type=int, default=0,
|
174 |
+
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
175 |
+
parser.add_argument('--losses_log_every', type=int, default=25,
|
176 |
+
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
|
177 |
+
parser.add_argument('--load_best_score', type=int, default=1,
|
178 |
+
help='Do we load previous best score when resuming training.')
|
179 |
+
|
180 |
+
# misc
|
181 |
+
parser.add_argument('--id', type=str, default='',
|
182 |
+
help='an id identifying this run/job. used in cross-val and appended when writing progress files')
|
183 |
+
parser.add_argument('--train_only', type=int, default=0,
|
184 |
+
help='if true then use 80k, else use 110k')
|
185 |
+
|
186 |
+
|
187 |
+
# Reward
|
188 |
+
parser.add_argument('--cider_reward_weight', type=float, default=1,
|
189 |
+
help='The reward weight from cider')
|
190 |
+
parser.add_argument('--bleu_reward_weight', type=float, default=0,
|
191 |
+
help='The reward weight from bleu4')
|
192 |
+
|
193 |
+
# Reward
|
194 |
+
parser.add_argument('--clipscore_reward_weight', type=float, default=1,
|
195 |
+
help='The reward weight from clipscore')
|
196 |
+
parser.add_argument('--use_clipscore', type=float, default=0,
|
197 |
+
help='Use CLIPScore')
|
198 |
+
parser.add_argument('--clipscore_mode', type=str, default='clip_s',
|
199 |
+
help='Which CLIPScore to use: clip_s|refclip_s')
|
200 |
+
|
201 |
+
|
202 |
+
# Structure_loss
|
203 |
+
parser.add_argument('--structure_loss_weight', type=float, default=1,
|
204 |
+
help='')
|
205 |
+
parser.add_argument('--structure_after', type=int, default=-1,
|
206 |
+
help='T')
|
207 |
+
parser.add_argument('--structure_loss_type', type=str, default='seqnll',
|
208 |
+
help='')
|
209 |
+
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
|
210 |
+
parser.add_argument('--entropy_reward_weight', type=float, default=0,
|
211 |
+
help='Entropy reward, seems very interesting')
|
212 |
+
parser.add_argument('--self_cider_reward_weight', type=float, default=0,
|
213 |
+
help='self cider reward')
|
214 |
+
|
215 |
+
# Used for self critical or structure. Used when sampling is need during training
|
216 |
+
parser.add_argument('--train_sample_n', type=int, default=16,
|
217 |
+
help='The reward weight from cider')
|
218 |
+
parser.add_argument('--train_sample_method', type=str, default='sample',
|
219 |
+
help='')
|
220 |
+
parser.add_argument('--train_beam_size', type=int, default=1,
|
221 |
+
help='')
|
222 |
+
|
223 |
+
# Used for self critical
|
224 |
+
parser.add_argument('--sc_sample_method', type=str, default='greedy',
|
225 |
+
help='')
|
226 |
+
parser.add_argument('--sc_beam_size', type=int, default=1,
|
227 |
+
help='')
|
228 |
+
|
229 |
+
|
230 |
+
# For diversity evaluation during training
|
231 |
+
add_diversity_opts(parser)
|
232 |
+
|
233 |
+
|
234 |
+
# config
|
235 |
+
parser.add_argument('--cfg', type=str, default=None,
|
236 |
+
help='configuration; similar to what is used in detectron')
|
237 |
+
parser.add_argument(
|
238 |
+
'--set_cfgs', dest='set_cfgs',
|
239 |
+
help='Set config keys. Key value sequence seperate by whitespace.'
|
240 |
+
'e.g. [key] [value] [key] [value]\n This has higher priority'
|
241 |
+
'than cfg file but lower than other args. (You can only overwrite'
|
242 |
+
'arguments that have alerady been defined in config file.)',
|
243 |
+
default=[], nargs='+')
|
244 |
+
# How will config be used
|
245 |
+
# 1) read cfg argument, and load the cfg file if it's not None
|
246 |
+
# 2) Overwrite cfg argument with set_cfgs
|
247 |
+
# 3) parse config argument to args.
|
248 |
+
# 4) in the end, parse command line argument and overwrite args
|
249 |
+
|
250 |
+
# step 1: read cfg_fn
|
251 |
+
# args = parser.parse_args()
|
252 |
+
# Parse the arguments.
|
253 |
+
if parse:
|
254 |
+
args = parser.parse_args()
|
255 |
+
# For interative engironmnet (ex. jupyter)
|
256 |
+
else:
|
257 |
+
args = parser.parse_known_args()[0]
|
258 |
+
# print(args)
|
259 |
+
|
260 |
+
# Namespace => Dictionary
|
261 |
+
kwargs = vars(args)
|
262 |
+
# for k, v in optional_kwargs.items():
|
263 |
+
# setattr(args, k, v)
|
264 |
+
kwargs.update(optional_kwargs)
|
265 |
+
|
266 |
+
args = Config(**kwargs)
|
267 |
+
|
268 |
+
|
269 |
+
if args.cfg is not None or args.set_cfgs is not None:
|
270 |
+
from .config import CfgNode
|
271 |
+
if args.cfg is not None:
|
272 |
+
# print('Read Cfg')
|
273 |
+
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
|
274 |
+
# print(cn)
|
275 |
+
else:
|
276 |
+
cn = CfgNode()
|
277 |
+
if args.set_cfgs is not None:
|
278 |
+
cn.merge_from_list(args.set_cfgs)
|
279 |
+
for k,v in cn.items():
|
280 |
+
if not hasattr(args, k):
|
281 |
+
import os
|
282 |
+
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
283 |
+
pass
|
284 |
+
else:
|
285 |
+
print('Warning: key %s not in args' % k)
|
286 |
+
|
287 |
+
setattr(args, k, v)
|
288 |
+
|
289 |
+
if parse:
|
290 |
+
args = parser.parse_args(namespace=args)
|
291 |
+
else:
|
292 |
+
args = parser.parse_known_args(namespace=args)[0]
|
293 |
+
|
294 |
+
# Check if args are valid
|
295 |
+
assert args.rnn_size > 0, "rnn_size should be greater than 0"
|
296 |
+
assert args.num_layers > 0, "num_layers should be greater than 0"
|
297 |
+
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
|
298 |
+
assert args.batch_size > 0, "batch_size should be greater than 0"
|
299 |
+
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
|
300 |
+
assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
|
301 |
+
assert args.beam_size > 0, "beam_size should be greater than 0"
|
302 |
+
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
|
303 |
+
assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
|
304 |
+
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
|
305 |
+
assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
|
306 |
+
assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"
|
307 |
+
|
308 |
+
# default value for start_from and checkpoint_path
|
309 |
+
args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
|
310 |
+
args.start_from = args.start_from or args.checkpoint_path
|
311 |
+
|
312 |
+
# Deal with feature things before anything
|
313 |
+
args.use_fc, args.use_att = if_use_feat(args.caption_model)
|
314 |
+
if args.use_box: args.att_feat_size = args.att_feat_size + 5
|
315 |
+
|
316 |
+
return args
|
317 |
+
|
318 |
+
|
319 |
+
def add_eval_options(parser):
|
320 |
+
# Basic options
|
321 |
+
parser.add_argument('--batch_size', type=int, default=0,
|
322 |
+
help='if > 0 then overrule, otherwise load from checkpoint.')
|
323 |
+
parser.add_argument('--num_images', type=int, default=-1,
|
324 |
+
help='how many images to use when periodically evaluating the loss? (-1 = all)')
|
325 |
+
parser.add_argument('--language_eval', type=int, default=0,
|
326 |
+
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
327 |
+
parser.add_argument('--dump_images', type=int, default=1,
|
328 |
+
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
|
329 |
+
parser.add_argument('--dump_json', type=int, default=1,
|
330 |
+
help='Dump json with predictions into vis folder? (1=yes,0=no)')
|
331 |
+
parser.add_argument('--dump_path', type=int, default=0,
|
332 |
+
help='Write image paths along with predictions into vis json? (1=yes,0=no)')
|
333 |
+
|
334 |
+
# Sampling options
|
335 |
+
add_eval_sample_opts(parser)
|
336 |
+
|
337 |
+
# For evaluation on a folder of images:
|
338 |
+
parser.add_argument('--image_folder', type=str, default='',
|
339 |
+
help='If this is nonempty then will predict on the images in this folder path')
|
340 |
+
parser.add_argument('--image_root', type=str, default='',
|
341 |
+
help='In case the image paths have to be preprended with a root path to an image folder')
|
342 |
+
# For evaluation on MSCOCO images from some split:
|
343 |
+
parser.add_argument('--input_fc_dir', type=str, default='',
|
344 |
+
help='path to the h5file containing the preprocessed dataset')
|
345 |
+
parser.add_argument('--input_att_dir', type=str, default='',
|
346 |
+
help='path to the h5file containing the preprocessed dataset')
|
347 |
+
parser.add_argument('--input_box_dir', type=str, default='',
|
348 |
+
help='path to the h5file containing the preprocessed dataset')
|
349 |
+
parser.add_argument('--input_label_h5', type=str, default='',
|
350 |
+
help='path to the h5file containing the preprocessed dataset')
|
351 |
+
parser.add_argument('--input_json', type=str, default='',
|
352 |
+
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
|
353 |
+
parser.add_argument('--split', type=str, default='test',
|
354 |
+
help='if running on MSCOCO images, which split to use: val|test|train')
|
355 |
+
parser.add_argument('--coco_json', type=str, default='',
|
356 |
+
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
|
357 |
+
# misc
|
358 |
+
parser.add_argument('--id', type=str, default='',
|
359 |
+
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
|
360 |
+
parser.add_argument('--verbose_beam', type=int, default=1,
|
361 |
+
help='if we need to print out all beam search beams.')
|
362 |
+
parser.add_argument('--verbose_loss', type=int, default=0,
|
363 |
+
help='If calculate loss using ground truth during evaluation')
|
364 |
+
|
365 |
+
def add_diversity_opts(parser):
|
366 |
+
parser.add_argument('--sample_n', type=int, default=1,
|
367 |
+
help='Diverse sampling')
|
368 |
+
parser.add_argument('--sample_n_method', type=str, default='sample',
|
369 |
+
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
|
370 |
+
parser.add_argument('--eval_oracle', type=int, default=1,
|
371 |
+
help='if we need to calculate loss.')
|
372 |
+
|
373 |
+
|
374 |
+
# Sampling related options
|
375 |
+
def add_eval_sample_opts(parser):
|
376 |
+
parser.add_argument('--sample_method', type=str, default='greedy',
|
377 |
+
help='greedy; sample; gumbel; top<int>, top<0-1>')
|
378 |
+
parser.add_argument('--beam_size', type=int, default=1,
|
379 |
+
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
|
380 |
+
parser.add_argument('--max_length', type=int, default=20,
|
381 |
+
help='Maximum length during sampling')
|
382 |
+
parser.add_argument('--length_penalty', type=str, default='',
|
383 |
+
help='wu_X or avg_X, X is the alpha')
|
384 |
+
parser.add_argument('--group_size', type=int, default=1,
|
385 |
+
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
|
386 |
+
parser.add_argument('--diversity_lambda', type=float, default=0.5,
|
387 |
+
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
|
388 |
+
parser.add_argument('--temperature', type=float, default=1.0,
|
389 |
+
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
|
390 |
+
parser.add_argument('--decoding_constraint', type=int, default=0,
|
391 |
+
help='If 1, not allowing same word in a row')
|
392 |
+
parser.add_argument('--block_trigrams', type=int, default=0,
|
393 |
+
help='block repeated trigram.')
|
394 |
+
parser.add_argument('--remove_bad_endings', type=int, default=0,
|
395 |
+
help='Remove bad endings')
|
396 |
+
parser.add_argument('--suppress_UNK', type=int, default=1,
|
397 |
+
help='Not predicting UNK')
|
398 |
+
|
399 |
+
|
400 |
+
if __name__ == '__main__':
|
401 |
+
import sys
|
402 |
+
sys.argv = [sys.argv[0]]
|
403 |
+
args = parse_opt()
|
404 |
+
print(args)
|
405 |
+
print()
|
406 |
+
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
|
407 |
+
args1 = parse_opt()
|
408 |
+
print(dict(set(vars(args1).items()) - set(vars(args).items())))
|
409 |
+
print()
|
410 |
+
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
|
411 |
+
args2 = parse_opt()
|
412 |
+
print(dict(set(vars(args2).items()) - set(vars(args1).items())))
|
captioning/utils/resnet.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models.resnet
|
4 |
+
from torchvision.models.resnet import BasicBlock, Bottleneck
|
5 |
+
|
6 |
+
class ResNet(torchvision.models.resnet.ResNet):
|
7 |
+
def __init__(self, block, layers, num_classes=1000):
|
8 |
+
super(ResNet, self).__init__(block, layers, num_classes)
|
9 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
|
10 |
+
for i in range(2, 5):
|
11 |
+
getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
|
12 |
+
getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
|
13 |
+
|
14 |
+
def resnet18(pretrained=False):
|
15 |
+
"""Constructs a ResNet-18 model.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
19 |
+
"""
|
20 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2])
|
21 |
+
if pretrained:
|
22 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
23 |
+
return model
|
24 |
+
|
25 |
+
|
26 |
+
def resnet34(pretrained=False):
|
27 |
+
"""Constructs a ResNet-34 model.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
31 |
+
"""
|
32 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3])
|
33 |
+
if pretrained:
|
34 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
|
35 |
+
return model
|
36 |
+
|
37 |
+
|
38 |
+
def resnet50(pretrained=False):
|
39 |
+
"""Constructs a ResNet-50 model.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
43 |
+
"""
|
44 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3])
|
45 |
+
if pretrained:
|
46 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
def resnet101(pretrained=False):
|
51 |
+
"""Constructs a ResNet-101 model.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
55 |
+
"""
|
56 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3])
|
57 |
+
if pretrained:
|
58 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
|
59 |
+
return model
|
60 |
+
|
61 |
+
|
62 |
+
def resnet152(pretrained=False):
|
63 |
+
"""Constructs a ResNet-152 model.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
67 |
+
"""
|
68 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3])
|
69 |
+
if pretrained:
|
70 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
|
71 |
+
return model
|
captioning/utils/resnet_utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class myResnet(nn.Module):
|
6 |
+
def __init__(self, resnet):
|
7 |
+
super(myResnet, self).__init__()
|
8 |
+
self.resnet = resnet
|
9 |
+
|
10 |
+
def forward(self, img, att_size=14):
|
11 |
+
x = img.unsqueeze(0)
|
12 |
+
|
13 |
+
x = self.resnet.conv1(x)
|
14 |
+
x = self.resnet.bn1(x)
|
15 |
+
x = self.resnet.relu(x)
|
16 |
+
x = self.resnet.maxpool(x)
|
17 |
+
|
18 |
+
x = self.resnet.layer1(x)
|
19 |
+
x = self.resnet.layer2(x)
|
20 |
+
x = self.resnet.layer3(x)
|
21 |
+
x = self.resnet.layer4(x)
|
22 |
+
|
23 |
+
fc = x.mean(3).mean(2).squeeze()
|
24 |
+
att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
|
25 |
+
|
26 |
+
return fc, att
|
27 |
+
|
captioning/utils/rewards.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
from collections import OrderedDict
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import sys
|
11 |
+
try:
|
12 |
+
sys.path.append("cider")
|
13 |
+
from pyciderevalcap.ciderD.ciderD import CiderD
|
14 |
+
from pyciderevalcap.cider.cider import Cider
|
15 |
+
sys.path.append("coco-caption")
|
16 |
+
from pycocoevalcap.bleu.bleu import Bleu
|
17 |
+
except:
|
18 |
+
print('cider or coco-caption missing')
|
19 |
+
|
20 |
+
CiderD_scorer = None
|
21 |
+
Cider_scorer = None
|
22 |
+
Bleu_scorer = None
|
23 |
+
#CiderD_scorer = CiderD(df='corpus')
|
24 |
+
|
25 |
+
|
26 |
+
from .misc import decode_sequence
|
27 |
+
|
28 |
+
def init_scorer(cached_tokens):
|
29 |
+
global CiderD_scorer
|
30 |
+
CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
|
31 |
+
global Cider_scorer
|
32 |
+
Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
|
33 |
+
global Bleu_scorer
|
34 |
+
Bleu_scorer = Bleu_scorer or Bleu(4)
|
35 |
+
|
36 |
+
def array_to_str(arr):
|
37 |
+
out = ''
|
38 |
+
for i in range(len(arr)):
|
39 |
+
out += str(arr[i]) + ' '
|
40 |
+
if arr[i] == 0:
|
41 |
+
break
|
42 |
+
return out.strip()
|
43 |
+
|
44 |
+
def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
|
45 |
+
batch_size = len(data_gts)
|
46 |
+
gen_result_size = gen_result.shape[0]
|
47 |
+
seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
|
48 |
+
assert greedy_res.shape[0] == batch_size
|
49 |
+
|
50 |
+
res = OrderedDict()
|
51 |
+
gen_result = gen_result.data.cpu().numpy()
|
52 |
+
greedy_res = greedy_res.data.cpu().numpy()
|
53 |
+
for i in range(gen_result_size):
|
54 |
+
res[i] = [array_to_str(gen_result[i])]
|
55 |
+
for i in range(batch_size):
|
56 |
+
res[gen_result_size + i] = [array_to_str(greedy_res[i])]
|
57 |
+
|
58 |
+
gts = OrderedDict()
|
59 |
+
for i in range(len(data_gts)):
|
60 |
+
gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
|
61 |
+
|
62 |
+
res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
|
63 |
+
res__ = {i: res[i] for i in range(len(res_))}
|
64 |
+
gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
|
65 |
+
gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
|
66 |
+
if opt.cider_reward_weight > 0:
|
67 |
+
_, cider_scores = CiderD_scorer.compute_score(gts_, res_)
|
68 |
+
if hasattr(opt, 'verbose') and not opt.verbose:
|
69 |
+
pass
|
70 |
+
else:
|
71 |
+
print('Cider scores:', _)
|
72 |
+
else:
|
73 |
+
cider_scores = 0
|
74 |
+
if opt.bleu_reward_weight > 0:
|
75 |
+
_, bleu_scores = Bleu_scorer.compute_score(gts_, res__)
|
76 |
+
bleu_scores = np.array(bleu_scores[3])
|
77 |
+
if hasattr(opt, 'verbose') and not opt.verbose:
|
78 |
+
pass
|
79 |
+
else:
|
80 |
+
print('Bleu scores:', _[3])
|
81 |
+
else:
|
82 |
+
bleu_scores = 0
|
83 |
+
scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
|
84 |
+
|
85 |
+
unnormalized_reward_mean = scores[:gen_result_size].flatten().mean()
|
86 |
+
|
87 |
+
scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
|
88 |
+
|
89 |
+
scores = scores.reshape(gen_result_size)
|
90 |
+
|
91 |
+
rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
|
92 |
+
|
93 |
+
return rewards, unnormalized_reward_mean
|
94 |
+
|
95 |
+
|
96 |
+
def get_self_critical_clipscore_reward(greedy_res, data_gts, gen_result, opt, clipscore_model, clip_vis_feats, vocab):
|
97 |
+
batch_size = len(data_gts)
|
98 |
+
gen_result_size = gen_result.shape[0]
|
99 |
+
seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
|
100 |
+
assert greedy_res.shape[0] == batch_size
|
101 |
+
|
102 |
+
B = batch_size
|
103 |
+
K = seq_per_img
|
104 |
+
L = gen_result.shape[1]
|
105 |
+
assert gen_result.shape == (B*K , L)
|
106 |
+
|
107 |
+
# res = OrderedDict()
|
108 |
+
# gen_result = gen_result.data.cpu().numpy()
|
109 |
+
# greedy_res = greedy_res.data.cpu().numpy()
|
110 |
+
# for i in range(gen_result_size):
|
111 |
+
# res[i] = [array_to_str(gen_result[i])]
|
112 |
+
# for i in range(batch_size):
|
113 |
+
# res[gen_result_size + i] = [array_to_str(greedy_res[i])]
|
114 |
+
|
115 |
+
# gts = OrderedDict()
|
116 |
+
# for i in range(len(data_gts)):
|
117 |
+
# gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
|
118 |
+
|
119 |
+
# res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
|
120 |
+
# res__ = {i: res[i] for i in range(len(res_))}
|
121 |
+
# gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
|
122 |
+
# gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
|
123 |
+
|
124 |
+
# res = []
|
125 |
+
# gen_result = gen_result.data.cpu().numpy()
|
126 |
+
# greedy_res = greedy_res.data.cpu().numpy()
|
127 |
+
# # for i in range(gen_result_size):
|
128 |
+
# # res.append(array_to_str(gen_result[i]))
|
129 |
+
# res.extend(decode_sequence(vocab, gen_result))
|
130 |
+
|
131 |
+
|
132 |
+
# # for i in range(batch_size):
|
133 |
+
# # res.append(array_to_str(greedy_res[i]))
|
134 |
+
# res.extend(decode_sequence(vocab, greedy_res))
|
135 |
+
|
136 |
+
if clipscore_model.mode == 'refclip_s':
|
137 |
+
gts = []
|
138 |
+
gts_valid_mask = []
|
139 |
+
max_n_refs = max([len(_gts) for _gts in data_gts])
|
140 |
+
for i in range(len(data_gts)):
|
141 |
+
_gts = decode_sequence(vocab, data_gts[i])
|
142 |
+
# pad references
|
143 |
+
n_ref = len(_gts)
|
144 |
+
_gts.extend([''] * (max_n_refs - n_ref))
|
145 |
+
gts.extend(_gts)
|
146 |
+
gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
|
147 |
+
assert len(gts) == B * max_n_refs
|
148 |
+
assert len(gts_valid_mask) == B * max_n_refs
|
149 |
+
|
150 |
+
# print(gts)
|
151 |
+
# print(gts_valid_mask)
|
152 |
+
# exit()
|
153 |
+
|
154 |
+
|
155 |
+
# assert len(res) == B * K + B, len(res)
|
156 |
+
|
157 |
+
# print(res)
|
158 |
+
# exit()
|
159 |
+
|
160 |
+
if opt.clipscore_reward_weight > 0:
|
161 |
+
with torch.no_grad():
|
162 |
+
clipscore_model.eval()
|
163 |
+
|
164 |
+
# 1) calculate reward
|
165 |
+
gen_result = gen_result.data.cpu().numpy()
|
166 |
+
res = decode_sequence(vocab, gen_result)
|
167 |
+
assert len(res) == B * K, len(res)
|
168 |
+
|
169 |
+
# [B * K, dim)
|
170 |
+
if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
|
171 |
+
text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
|
172 |
+
|
173 |
+
grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
|
174 |
+
grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]
|
175 |
+
grammar_prob = grammar_prob.view(B*K).detach()
|
176 |
+
|
177 |
+
text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
|
178 |
+
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
|
179 |
+
|
180 |
+
else:
|
181 |
+
text_feat = clipscore_model.text_extract(res)
|
182 |
+
|
183 |
+
|
184 |
+
assert text_feat.size() == (B * K, 512), text_feat.size()
|
185 |
+
assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
|
186 |
+
|
187 |
+
# [B * K, dim]
|
188 |
+
vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
|
189 |
+
|
190 |
+
clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
|
191 |
+
clip_s = clip_s.view(B * K).detach()
|
192 |
+
|
193 |
+
if clipscore_model.mode == 'refclip_s':
|
194 |
+
# [B * n_ref, dim]
|
195 |
+
ref_text_feat = clipscore_model.text_extract(gts)
|
196 |
+
ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
|
197 |
+
|
198 |
+
assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
|
199 |
+
assert ref_text_mask.size() == (B * max_n_refs,), ref_text_mask.size()
|
200 |
+
|
201 |
+
# [B * K]
|
202 |
+
refclip_s = clipscore_model.calc_refclip_s(
|
203 |
+
text_feat=text_feat, img_feat=vis_feat,
|
204 |
+
ref_text_feat=ref_text_feat.view(B, 1, max_n_refs, -1).expand(-1, K, -1, -1).contiguous().view(B * K * max_n_refs, -1),
|
205 |
+
ref_text_mask=ref_text_mask.view(B, 1, max_n_refs).expand(-1, K, -1).contiguous().view(B * K * max_n_refs),
|
206 |
+
clip_s=clip_s)
|
207 |
+
refclip_s = refclip_s.view(B * K).detach()
|
208 |
+
|
209 |
+
# 2) calcualte reward for baseline (greedy)
|
210 |
+
greedy_res = greedy_res.data.cpu().numpy()
|
211 |
+
res = decode_sequence(vocab, greedy_res)
|
212 |
+
assert len(res) == B, len(res)
|
213 |
+
|
214 |
+
# [B, dim)
|
215 |
+
|
216 |
+
if getattr(opt, 'use_grammar', False) and getattr(opt, 'use_grammar_baseline', False) and not getattr(opt, 'joint_out', False):
|
217 |
+
text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
|
218 |
+
|
219 |
+
grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
|
220 |
+
grammar_prob_baseline = torch.softmax(grammar_logit, dim=-1)[:, 1]
|
221 |
+
grammar_prob_baseline = grammar_prob_baseline.view(B).detach()
|
222 |
+
|
223 |
+
text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
|
224 |
+
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
|
225 |
+
else:
|
226 |
+
text_feat = clipscore_model.text_extract(res)
|
227 |
+
|
228 |
+
assert text_feat.size() == (B, 512), text_feat.size()
|
229 |
+
assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
|
230 |
+
|
231 |
+
vis_feat = clip_vis_feats.view(B, 512)
|
232 |
+
|
233 |
+
# [B]
|
234 |
+
clip_s_baseline = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
|
235 |
+
clip_s_baseline = clip_s_baseline.view(B).detach()
|
236 |
+
|
237 |
+
if clipscore_model.mode == 'refclip_s':
|
238 |
+
# # [B * n_ref]
|
239 |
+
# ref_text_feat = clipscore_model.text_extract(gts)
|
240 |
+
# ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
|
241 |
+
# assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
|
242 |
+
# assert ref_text_mask.size() == (B * max_n_refs), ref_text_mask.size()
|
243 |
+
|
244 |
+
# [B]
|
245 |
+
refclip_s_baseline = clipscore_model.calc_refclip_s(
|
246 |
+
text_feat=text_feat, img_feat=vis_feat,
|
247 |
+
ref_text_feat=ref_text_feat,
|
248 |
+
ref_text_mask=ref_text_mask,
|
249 |
+
clip_s=clip_s_baseline)
|
250 |
+
refclip_s_baseline = refclip_s_baseline.view(B).detach()
|
251 |
+
|
252 |
+
if clipscore_model.mode == 'clip_s':
|
253 |
+
rewards = clip_s - clip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
|
254 |
+
unnormalized_mean_reward = clip_s.mean()
|
255 |
+
elif clipscore_model.mode == 'refclip_s':
|
256 |
+
rewards = refclip_s - refclip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
|
257 |
+
unnormalized_mean_reward = refclip_s.mean()
|
258 |
+
|
259 |
+
# # [B * K + B, dim)
|
260 |
+
# text_feat = clipscore_model.text_extract(res)
|
261 |
+
# assert text_feat.size() == (B * K + B, 512), text_feat.size()
|
262 |
+
|
263 |
+
# assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
|
264 |
+
|
265 |
+
# # [B, dim] -> [B * K + B, dim]
|
266 |
+
# # vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K + 1, -1).contiguous().view(B * (K + 1), -1)
|
267 |
+
# # vis_feat = clip_vis_feats.view(1, B, -1).expand(K + 1, -1, -1).contiguous().view((K + 1) * B, -1)
|
268 |
+
|
269 |
+
# # [B * K, dim]
|
270 |
+
# gen_vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
|
271 |
+
# # [B, dim]
|
272 |
+
# greedy_vis_feat = clip_vis_feats
|
273 |
+
# # [B * K + B, dim]
|
274 |
+
# vis_feat = torch.cat([gen_vis_feat, greedy_vis_feat], dim=0)
|
275 |
+
|
276 |
+
# # if clipscore_model.mode == 'clip_s':
|
277 |
+
# # [B * K + B, dim]
|
278 |
+
# clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat)
|
279 |
+
# clip_s = clip_s.view(B * K + B).detach()
|
280 |
+
|
281 |
+
|
282 |
+
# if clipscore_model.mode == 'refclip_s':
|
283 |
+
# # [B * K, dim]
|
284 |
+
# ref_text_feat = clipscore_model.text_extract(gts)
|
285 |
+
|
286 |
+
# clipscore_scores = clipscore_model.calc_refclip_s(text_feat=text_feat, img_feat=vis_feat, ref_text_feat=ref_text_feat, clip_s=clip_s)
|
287 |
+
# clipscore_scores = clipscore_scores.view(B * K + B).detach()
|
288 |
+
|
289 |
+
if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
|
290 |
+
|
291 |
+
if getattr(opt, 'use_grammar_baseline', False):
|
292 |
+
grammar_rewards = grammar_prob - grammar_prob_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
|
293 |
+
else:
|
294 |
+
grammar_rewards = grammar_prob
|
295 |
+
else:
|
296 |
+
grammar_rewards = None
|
297 |
+
|
298 |
+
|
299 |
+
if hasattr(opt, 'verbose') and not opt.verbose:
|
300 |
+
pass
|
301 |
+
else:
|
302 |
+
if clipscore_model.mode == 'clip_s':
|
303 |
+
print('CLIP-S:', rewards)
|
304 |
+
elif clipscore_model.mode == 'refclip_s':
|
305 |
+
print('RefCLIP-S:', rewards)
|
306 |
+
else:
|
307 |
+
rewards = torch.zeros(B, L)
|
308 |
+
unnormalized_mean_reward = None
|
309 |
+
grammar_rewards = None
|
310 |
+
|
311 |
+
|
312 |
+
rewards = opt.clipscore_reward_weight * rewards
|
313 |
+
|
314 |
+
|
315 |
+
# scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
|
316 |
+
# scores = scores.reshape(gen_result_size)
|
317 |
+
# rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
|
318 |
+
|
319 |
+
# [B, K]
|
320 |
+
# scores = scores[:gen_result_size].reshape(B, K) - scores[-B:].unsqueeze(1)
|
321 |
+
|
322 |
+
# [B*K, L]
|
323 |
+
# rewards = scores.view(-1, 1).expand(-1, L).contiguous()
|
324 |
+
rewards = rewards.view(-1, 1).expand(-1, L).contiguous()
|
325 |
+
|
326 |
+
if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
|
327 |
+
grammar_rewards = grammar_rewards.view(-1, 1).expand(-1, L).contiguous()
|
328 |
+
|
329 |
+
return rewards, unnormalized_mean_reward, grammar_rewards
|
330 |
+
|
331 |
+
def get_scores(data_gts, gen_result, opt):
|
332 |
+
batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
|
333 |
+
seq_per_img = batch_size // len(data_gts)
|
334 |
+
|
335 |
+
res = OrderedDict()
|
336 |
+
|
337 |
+
gen_result = gen_result.data.cpu().numpy()
|
338 |
+
for i in range(batch_size):
|
339 |
+
res[i] = [array_to_str(gen_result[i])]
|
340 |
+
|
341 |
+
gts = OrderedDict()
|
342 |
+
for i in range(len(data_gts)):
|
343 |
+
gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
|
344 |
+
|
345 |
+
res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)]
|
346 |
+
res__ = {i: res[i] for i in range(batch_size)}
|
347 |
+
gts = {i: gts[i // seq_per_img] for i in range(batch_size)}
|
348 |
+
if opt.cider_reward_weight > 0:
|
349 |
+
_, cider_scores = CiderD_scorer.compute_score(gts, res_)
|
350 |
+
# print('Cider scores:', _)
|
351 |
+
if hasattr(opt, 'verbose') and not opt.verbose:
|
352 |
+
pass
|
353 |
+
else:
|
354 |
+
print('Cider scores:', _)
|
355 |
+
else:
|
356 |
+
cider_scores = 0
|
357 |
+
if opt.bleu_reward_weight > 0:
|
358 |
+
_, bleu_scores = Bleu_scorer.compute_score(gts, res__)
|
359 |
+
bleu_scores = np.array(bleu_scores[3])
|
360 |
+
# print('Bleu scores:', _[3])
|
361 |
+
if hasattr(opt, 'verbose') and not opt.verbose:
|
362 |
+
pass
|
363 |
+
else:
|
364 |
+
print('Bleu scores:', _[3])
|
365 |
+
else:
|
366 |
+
bleu_scores = 0
|
367 |
+
|
368 |
+
scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
|
369 |
+
|
370 |
+
return scores
|
371 |
+
|
372 |
+
def get_self_cider_scores(data_gts, gen_result, opt):
|
373 |
+
batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
|
374 |
+
seq_per_img = batch_size // len(data_gts)
|
375 |
+
|
376 |
+
res = []
|
377 |
+
|
378 |
+
gen_result = gen_result.data.cpu().numpy()
|
379 |
+
for i in range(batch_size):
|
380 |
+
res.append(array_to_str(gen_result[i]))
|
381 |
+
|
382 |
+
scores = []
|
383 |
+
for i in range(len(data_gts)):
|
384 |
+
tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]])
|
385 |
+
def get_div(eigvals):
|
386 |
+
eigvals = np.clip(eigvals, 0, None)
|
387 |
+
return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
|
388 |
+
scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10)))
|
389 |
+
|
390 |
+
scores = np.array(scores)
|
391 |
+
|
392 |
+
return scores
|
captioning/utils/utils.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import collections
|
6 |
+
import logging
|
7 |
+
|
8 |
+
def get_area(pos):
|
9 |
+
"""
|
10 |
+
Args
|
11 |
+
pos: [B, N, 4]
|
12 |
+
(x1, x2, y1, y2)
|
13 |
+
|
14 |
+
Return
|
15 |
+
area : [B, N]
|
16 |
+
"""
|
17 |
+
# [B, N]
|
18 |
+
height = pos[:, :, 3] - pos[:, :, 2]
|
19 |
+
width = pos[:, :, 1] - pos[:, :, 0]
|
20 |
+
area = height * width
|
21 |
+
return area
|
22 |
+
|
23 |
+
def get_relative_distance(pos):
|
24 |
+
"""
|
25 |
+
Args
|
26 |
+
pos: [B, N, 4]
|
27 |
+
(x1, x2, y1, y2)
|
28 |
+
|
29 |
+
Return
|
30 |
+
out : [B, N, N, 4]
|
31 |
+
"""
|
32 |
+
# B, N = pos.size()[:-1]
|
33 |
+
|
34 |
+
# [B, N, N, 4]
|
35 |
+
relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2)
|
36 |
+
|
37 |
+
return relative_distance
|
38 |
+
|
39 |
+
|
40 |
+
class LossMeter(object):
|
41 |
+
def __init__(self, maxlen=100):
|
42 |
+
"""Computes and stores the running average"""
|
43 |
+
self.vals = collections.deque([], maxlen=maxlen)
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return len(self.vals)
|
47 |
+
|
48 |
+
def update(self, new_val):
|
49 |
+
self.vals.append(new_val)
|
50 |
+
|
51 |
+
@property
|
52 |
+
def val(self):
|
53 |
+
return sum(self.vals) / len(self.vals)
|
54 |
+
|
55 |
+
def __repr__(self):
|
56 |
+
return str(self.val)
|
57 |
+
|
58 |
+
|
59 |
+
def count_parameters(model):
|
60 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
61 |
+
|
62 |
+
|
63 |
+
def load_state_dict(state_dict_path, loc='cpu'):
|
64 |
+
state_dict = torch.load(state_dict_path, map_location=loc)
|
65 |
+
# Change Multi GPU to single GPU
|
66 |
+
original_keys = list(state_dict.keys())
|
67 |
+
for key in original_keys:
|
68 |
+
if key.startswith("module."):
|
69 |
+
new_key = key[len("module."):]
|
70 |
+
state_dict[new_key] = state_dict.pop(key)
|
71 |
+
return state_dict
|
72 |
+
|
73 |
+
|
74 |
+
def set_global_logging_level(level=logging.ERROR, prefices=[""]):
|
75 |
+
"""
|
76 |
+
Override logging levels of different modules based on their name as a prefix.
|
77 |
+
It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
- level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
|
81 |
+
- prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
|
82 |
+
Default is `[""]` to match all active loggers.
|
83 |
+
The match is a case-sensitive `module_name.startswith(prefix)`
|
84 |
+
"""
|
85 |
+
prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
|
86 |
+
for name in logging.root.manager.loggerDict:
|
87 |
+
if re.match(prefix_re, name):
|
88 |
+
logging.getLogger(name).setLevel(level)
|
89 |
+
|
90 |
+
|
91 |
+
def get_iou(anchors, gt_boxes):
|
92 |
+
"""
|
93 |
+
anchors: (N, 4) torch floattensor
|
94 |
+
gt_boxes: (K, 4) torch floattensor
|
95 |
+
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
|
96 |
+
"""
|
97 |
+
N = anchors.size(0)
|
98 |
+
|
99 |
+
if gt_boxes.size() == (4,):
|
100 |
+
gt_boxes = gt_boxes.view(1, 4)
|
101 |
+
K = gt_boxes.size(0)
|
102 |
+
|
103 |
+
gt_boxes_area = (
|
104 |
+
(gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
|
105 |
+
(gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
|
106 |
+
).view(1, K)
|
107 |
+
|
108 |
+
anchors_area = (
|
109 |
+
(anchors[:, 2] - anchors[:, 0] + 1) *
|
110 |
+
(anchors[:, 3] - anchors[:, 1] + 1)
|
111 |
+
).view(N, 1)
|
112 |
+
|
113 |
+
boxes = anchors.view(N, 1, 4).expand(N, K, 4)
|
114 |
+
query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
|
115 |
+
|
116 |
+
iw = (
|
117 |
+
torch.min(boxes[:, :, 2], query_boxes[:, :, 2])
|
118 |
+
- torch.max(boxes[:, :, 0], query_boxes[:, :, 0])
|
119 |
+
+ 1
|
120 |
+
)
|
121 |
+
iw[iw < 0] = 0
|
122 |
+
|
123 |
+
ih = (
|
124 |
+
torch.min(boxes[:, :, 3], query_boxes[:, :, 3])
|
125 |
+
- torch.max(boxes[:, :, 1], query_boxes[:, :, 1])
|
126 |
+
+ 1
|
127 |
+
)
|
128 |
+
ih[ih < 0] = 0
|
129 |
+
|
130 |
+
ua = anchors_area + gt_boxes_area - (iw * ih)
|
131 |
+
overlaps = iw * ih / ua
|
132 |
+
|
133 |
+
return overlaps
|
134 |
+
|
135 |
+
|
136 |
+
def xywh_to_xyxy(boxes):
|
137 |
+
"""Convert [x y w h] box format to [x1 y1 x2 y2] format."""
|
138 |
+
return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))
|
clip/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .clip import *
|
clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
clip/clip.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Union, List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from .model import build_model
|
13 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
14 |
+
|
15 |
+
__all__ = ["available_models", "load", "tokenize"]
|
16 |
+
_tokenizer = _Tokenizer()
|
17 |
+
|
18 |
+
_MODELS = {
|
19 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
20 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
21 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
22 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
27 |
+
os.makedirs(root, exist_ok=True)
|
28 |
+
filename = os.path.basename(url)
|
29 |
+
|
30 |
+
expected_sha256 = url.split("/")[-2]
|
31 |
+
download_target = os.path.join(root, filename)
|
32 |
+
|
33 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
34 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
35 |
+
|
36 |
+
if os.path.isfile(download_target):
|
37 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
38 |
+
return download_target
|
39 |
+
else:
|
40 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
41 |
+
|
42 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
43 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
44 |
+
while True:
|
45 |
+
buffer = source.read(8192)
|
46 |
+
if not buffer:
|
47 |
+
break
|
48 |
+
|
49 |
+
output.write(buffer)
|
50 |
+
loop.update(len(buffer))
|
51 |
+
|
52 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
53 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
54 |
+
|
55 |
+
return download_target
|
56 |
+
|
57 |
+
|
58 |
+
def _transform(n_px):
|
59 |
+
return Compose([
|
60 |
+
Resize(n_px, interpolation=Image.BICUBIC),
|
61 |
+
CenterCrop(n_px),
|
62 |
+
lambda image: image.convert("RGB"),
|
63 |
+
ToTensor(),
|
64 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
65 |
+
])
|
66 |
+
|
67 |
+
|
68 |
+
def available_models() -> List[str]:
|
69 |
+
"""Returns the names of available CLIP models"""
|
70 |
+
return list(_MODELS.keys())
|
71 |
+
|
72 |
+
|
73 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
|
74 |
+
"""Load a CLIP model
|
75 |
+
|
76 |
+
Parameters
|
77 |
+
----------
|
78 |
+
name : str
|
79 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
80 |
+
|
81 |
+
device : Union[str, torch.device]
|
82 |
+
The device to put the loaded model
|
83 |
+
|
84 |
+
jit : bool
|
85 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
model : torch.nn.Module
|
90 |
+
The CLIP model
|
91 |
+
|
92 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
93 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
94 |
+
"""
|
95 |
+
if name in _MODELS:
|
96 |
+
model_path = _download(_MODELS[name])
|
97 |
+
elif os.path.isfile(name):
|
98 |
+
model_path = name
|
99 |
+
else:
|
100 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
101 |
+
|
102 |
+
try:
|
103 |
+
# loading JIT archive
|
104 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
105 |
+
state_dict = None
|
106 |
+
except RuntimeError:
|
107 |
+
# loading saved state dict
|
108 |
+
if jit:
|
109 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
110 |
+
jit = False
|
111 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
112 |
+
|
113 |
+
if not jit:
|
114 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
115 |
+
if str(device) == "cpu":
|
116 |
+
model.float()
|
117 |
+
return model, _transform(model.visual.input_resolution)
|
118 |
+
|
119 |
+
# patch the device names
|
120 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
121 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
122 |
+
|
123 |
+
def patch_device(module):
|
124 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
125 |
+
if hasattr(module, "forward1"):
|
126 |
+
graphs.append(module.forward1.graph)
|
127 |
+
|
128 |
+
for graph in graphs:
|
129 |
+
for node in graph.findAllNodes("prim::Constant"):
|
130 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
131 |
+
node.copyAttributes(device_node)
|
132 |
+
|
133 |
+
model.apply(patch_device)
|
134 |
+
patch_device(model.encode_image)
|
135 |
+
patch_device(model.encode_text)
|
136 |
+
|
137 |
+
# patch dtype to float32 on CPU
|
138 |
+
if str(device) == "cpu":
|
139 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
140 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
141 |
+
float_node = float_input.node()
|
142 |
+
|
143 |
+
def patch_float(module):
|
144 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
145 |
+
if hasattr(module, "forward1"):
|
146 |
+
graphs.append(module.forward1.graph)
|
147 |
+
|
148 |
+
for graph in graphs:
|
149 |
+
for node in graph.findAllNodes("aten::to"):
|
150 |
+
inputs = list(node.inputs())
|
151 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
152 |
+
if inputs[i].node()["value"] == 5:
|
153 |
+
inputs[i].node().copyAttributes(float_node)
|
154 |
+
|
155 |
+
model.apply(patch_float)
|
156 |
+
patch_float(model.encode_image)
|
157 |
+
patch_float(model.encode_text)
|
158 |
+
|
159 |
+
model.float()
|
160 |
+
|
161 |
+
return model, _transform(model.input_resolution.item())
|
162 |
+
|
163 |
+
|
164 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
165 |
+
"""
|
166 |
+
Returns the tokenized representation of given input string(s)
|
167 |
+
|
168 |
+
Parameters
|
169 |
+
----------
|
170 |
+
texts : Union[str, List[str]]
|
171 |
+
An input string or a list of input strings to tokenize
|
172 |
+
|
173 |
+
context_length : int
|
174 |
+
The context length to use; all CLIP models use 77 as the context length
|
175 |
+
|
176 |
+
Returns
|
177 |
+
-------
|
178 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
179 |
+
"""
|
180 |
+
if isinstance(texts, str):
|
181 |
+
texts = [texts]
|
182 |
+
|
183 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
184 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
185 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
186 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
187 |
+
|
188 |
+
for i, tokens in enumerate(all_tokens):
|
189 |
+
if len(tokens) > context_length:
|
190 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
191 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
192 |
+
|
193 |
+
return result
|
clip/model.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class Bottleneck(nn.Module):
|
10 |
+
expansion = 4
|
11 |
+
|
12 |
+
def __init__(self, inplanes, planes, stride=1):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
16 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
17 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
18 |
+
|
19 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
20 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
21 |
+
|
22 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
23 |
+
|
24 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
25 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
26 |
+
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
self.stride = stride
|
30 |
+
|
31 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
32 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
33 |
+
self.downsample = nn.Sequential(OrderedDict([
|
34 |
+
("-1", nn.AvgPool2d(stride)),
|
35 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
36 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
37 |
+
]))
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor):
|
40 |
+
identity = x
|
41 |
+
|
42 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
43 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
44 |
+
out = self.avgpool(out)
|
45 |
+
out = self.bn3(self.conv3(out))
|
46 |
+
|
47 |
+
if self.downsample is not None:
|
48 |
+
identity = self.downsample(x)
|
49 |
+
|
50 |
+
out += identity
|
51 |
+
out = self.relu(out)
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class AttentionPool2d(nn.Module):
|
56 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
57 |
+
super().__init__()
|
58 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
59 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
60 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
61 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
62 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
63 |
+
self.num_heads = num_heads
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
67 |
+
# print(x.shape, self.positional_embedding.shape)
|
68 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
69 |
+
x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC
|
70 |
+
x, _ = F.multi_head_attention_forward(
|
71 |
+
query=x, key=x, value=x,
|
72 |
+
embed_dim_to_check=x.shape[-1],
|
73 |
+
num_heads=self.num_heads,
|
74 |
+
q_proj_weight=self.q_proj.weight,
|
75 |
+
k_proj_weight=self.k_proj.weight,
|
76 |
+
v_proj_weight=self.v_proj.weight,
|
77 |
+
in_proj_weight=None,
|
78 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
79 |
+
bias_k=None,
|
80 |
+
bias_v=None,
|
81 |
+
add_zero_attn=False,
|
82 |
+
dropout_p=0,
|
83 |
+
out_proj_weight=torch.ones_like(self.q_proj.weight),
|
84 |
+
out_proj_bias=torch.zeros_like(self.q_proj.bias),
|
85 |
+
# out_proj_weight=self.c_proj.weight,
|
86 |
+
# out_proj_bias=self.c_proj.bias,
|
87 |
+
use_separate_proj_weight=True,
|
88 |
+
training=self.training,
|
89 |
+
need_weights=False
|
90 |
+
)
|
91 |
+
|
92 |
+
return x[0]
|
93 |
+
|
94 |
+
|
95 |
+
class ModifiedResNet(nn.Module):
|
96 |
+
"""
|
97 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
98 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
99 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
100 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
104 |
+
super().__init__()
|
105 |
+
self.output_dim = output_dim
|
106 |
+
self.input_resolution = input_resolution
|
107 |
+
|
108 |
+
# the 3-layer stem
|
109 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
110 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
111 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
112 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
113 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
114 |
+
self.bn3 = nn.BatchNorm2d(width)
|
115 |
+
self.avgpool = nn.AvgPool2d(2)
|
116 |
+
self.relu = nn.ReLU(inplace=True)
|
117 |
+
|
118 |
+
# residual layers
|
119 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
120 |
+
self.layer1 = self._make_layer(width, layers[0])
|
121 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
122 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
123 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
124 |
+
|
125 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
126 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
127 |
+
|
128 |
+
def _make_layer(self, planes, blocks, stride=1):
|
129 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
130 |
+
|
131 |
+
self._inplanes = planes * Bottleneck.expansion
|
132 |
+
for _ in range(1, blocks):
|
133 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
134 |
+
|
135 |
+
return nn.Sequential(*layers)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
def stem(x):
|
139 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
140 |
+
x = self.relu(bn(conv(x)))
|
141 |
+
x = self.avgpool(x)
|
142 |
+
return x
|
143 |
+
|
144 |
+
x = x.type(self.conv1.weight.dtype)
|
145 |
+
x = stem(x)
|
146 |
+
x = self.layer1(x)
|
147 |
+
x = self.layer2(x)
|
148 |
+
x = self.layer3(x)
|
149 |
+
x = self.layer4(x)
|
150 |
+
# print(x.shape)
|
151 |
+
# x = self.attnpool(x)
|
152 |
+
attnpool = self.attnpool(x)
|
153 |
+
|
154 |
+
return (x, attnpool)
|
155 |
+
|
156 |
+
|
157 |
+
class LayerNorm(nn.LayerNorm):
|
158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
159 |
+
|
160 |
+
def forward(self, x: torch.Tensor):
|
161 |
+
orig_type = x.dtype
|
162 |
+
ret = super().forward(x.type(torch.float32))
|
163 |
+
return ret.type(orig_type)
|
164 |
+
|
165 |
+
|
166 |
+
class QuickGELU(nn.Module):
|
167 |
+
def forward(self, x: torch.Tensor):
|
168 |
+
return x * torch.sigmoid(1.702 * x)
|
169 |
+
|
170 |
+
|
171 |
+
class ResidualAttentionBlock(nn.Module):
|
172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
176 |
+
self.ln_1 = LayerNorm(d_model)
|
177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
179 |
+
("gelu", QuickGELU()),
|
180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
181 |
+
]))
|
182 |
+
self.ln_2 = LayerNorm(d_model)
|
183 |
+
self.attn_mask = attn_mask
|
184 |
+
|
185 |
+
def attention(self, x: torch.Tensor):
|
186 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
187 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor):
|
190 |
+
x = x + self.attention(self.ln_1(x))
|
191 |
+
x = x + self.mlp(self.ln_2(x))
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class Transformer(nn.Module):
|
196 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
197 |
+
super().__init__()
|
198 |
+
self.width = width
|
199 |
+
self.layers = layers
|
200 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
201 |
+
|
202 |
+
def forward(self, x: torch.Tensor):
|
203 |
+
return self.resblocks(x)
|
204 |
+
|
205 |
+
|
206 |
+
class VisualTransformer(nn.Module):
|
207 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
208 |
+
super().__init__()
|
209 |
+
self.input_resolution = input_resolution
|
210 |
+
self.output_dim = output_dim
|
211 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
212 |
+
|
213 |
+
scale = width ** -0.5
|
214 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
215 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
216 |
+
self.ln_pre = LayerNorm(width)
|
217 |
+
|
218 |
+
self.transformer = Transformer(width, layers, heads)
|
219 |
+
|
220 |
+
self.ln_post = LayerNorm(width)
|
221 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
222 |
+
|
223 |
+
def forward(self, x: torch.Tensor):
|
224 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
225 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
226 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
227 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
228 |
+
x = x + self.positional_embedding.to(x.dtype)
|
229 |
+
x = self.ln_pre(x)
|
230 |
+
|
231 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
232 |
+
x = self.transformer(x)
|
233 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
234 |
+
|
235 |
+
# x = self.ln_post(x[:, 0, :])
|
236 |
+
|
237 |
+
x = self.ln_post(x)
|
238 |
+
# if self.proj is not None:
|
239 |
+
# x = x @ self.proj
|
240 |
+
|
241 |
+
return x
|
242 |
+
|
243 |
+
|
244 |
+
class CLIP(nn.Module):
|
245 |
+
def __init__(self,
|
246 |
+
embed_dim: int,
|
247 |
+
# vision
|
248 |
+
image_resolution: int,
|
249 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
250 |
+
vision_width: int,
|
251 |
+
vision_patch_size: int,
|
252 |
+
# text
|
253 |
+
context_length: int,
|
254 |
+
vocab_size: int,
|
255 |
+
transformer_width: int,
|
256 |
+
transformer_heads: int,
|
257 |
+
transformer_layers: int
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
|
261 |
+
self.context_length = context_length
|
262 |
+
|
263 |
+
if isinstance(vision_layers, (tuple, list)):
|
264 |
+
vision_heads = vision_width * 32 // 64
|
265 |
+
self.visual = ModifiedResNet(
|
266 |
+
layers=vision_layers,
|
267 |
+
output_dim=embed_dim,
|
268 |
+
heads=vision_heads,
|
269 |
+
input_resolution=image_resolution,
|
270 |
+
width=vision_width
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
vision_heads = vision_width // 64
|
274 |
+
self.visual = VisualTransformer(
|
275 |
+
input_resolution=image_resolution,
|
276 |
+
patch_size=vision_patch_size,
|
277 |
+
width=vision_width,
|
278 |
+
layers=vision_layers,
|
279 |
+
heads=vision_heads,
|
280 |
+
output_dim=embed_dim
|
281 |
+
)
|
282 |
+
|
283 |
+
self.transformer = Transformer(
|
284 |
+
width=transformer_width,
|
285 |
+
layers=transformer_layers,
|
286 |
+
heads=transformer_heads,
|
287 |
+
attn_mask=self.build_attention_mask()
|
288 |
+
)
|
289 |
+
|
290 |
+
self.vocab_size = vocab_size
|
291 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
292 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
293 |
+
self.ln_final = LayerNorm(transformer_width)
|
294 |
+
|
295 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
296 |
+
self.logit_scale = nn.Parameter(torch.ones([]))
|
297 |
+
|
298 |
+
self.initialize_parameters()
|
299 |
+
|
300 |
+
def initialize_parameters(self):
|
301 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
302 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
303 |
+
|
304 |
+
if isinstance(self.visual, ModifiedResNet):
|
305 |
+
if self.visual.attnpool is not None:
|
306 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
307 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
308 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
309 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
310 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
311 |
+
|
312 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
313 |
+
for name, param in resnet_block.named_parameters():
|
314 |
+
if name.endswith("bn3.weight"):
|
315 |
+
nn.init.zeros_(param)
|
316 |
+
|
317 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
318 |
+
attn_std = self.transformer.width ** -0.5
|
319 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
320 |
+
for block in self.transformer.resblocks:
|
321 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
322 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
323 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
324 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
325 |
+
|
326 |
+
if self.text_projection is not None:
|
327 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
328 |
+
|
329 |
+
def build_attention_mask(self):
|
330 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
331 |
+
# pytorch uses additive attention mask; fill with -inf
|
332 |
+
mask = torch.empty(self.context_length, self.context_length)
|
333 |
+
mask.fill_(float("-inf"))
|
334 |
+
mask.triu_(1) # zero out the lower diagonal
|
335 |
+
return mask
|
336 |
+
|
337 |
+
@property
|
338 |
+
def dtype(self):
|
339 |
+
return self.visual.conv1.weight.dtype
|
340 |
+
|
341 |
+
def encode_image(self, image):
|
342 |
+
return self.visual(image.type(self.dtype))
|
343 |
+
|
344 |
+
def encode_text(self, text):
|
345 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
346 |
+
|
347 |
+
x = x + self.positional_embedding.type(self.dtype)
|
348 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
349 |
+
x = self.transformer(x)
|
350 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
351 |
+
x = self.ln_final(x).type(self.dtype)
|
352 |
+
|
353 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
354 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
355 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
356 |
+
|
357 |
+
return x
|
358 |
+
|
359 |
+
def forward(self, image, text):
|
360 |
+
image_features = self.encode_image(image)
|
361 |
+
text_features = self.encode_text(text)
|
362 |
+
|
363 |
+
# normalized features
|
364 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
365 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
366 |
+
|
367 |
+
# cosine similarity as logits
|
368 |
+
logit_scale = self.logit_scale.exp()
|
369 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
370 |
+
logits_per_text = logit_scale * text_features @ image_features.t()
|
371 |
+
|
372 |
+
# shape = [global_batch_size, global_batch_size]
|
373 |
+
return logits_per_image, logits_per_text
|
374 |
+
|
375 |
+
|
376 |
+
def convert_weights(model: nn.Module):
|
377 |
+
"""Convert applicable model parameters to fp16"""
|
378 |
+
|
379 |
+
def _convert_weights_to_fp16(l):
|
380 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
381 |
+
l.weight.data = l.weight.data.half()
|
382 |
+
if l.bias is not None:
|
383 |
+
l.bias.data = l.bias.data.half()
|
384 |
+
|
385 |
+
if isinstance(l, nn.MultiheadAttention):
|
386 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
387 |
+
tensor = getattr(l, attr)
|
388 |
+
if tensor is not None:
|
389 |
+
tensor.data = tensor.data.half()
|
390 |
+
|
391 |
+
for name in ["text_projection", "proj"]:
|
392 |
+
if hasattr(l, name):
|
393 |
+
attr = getattr(l, name)
|
394 |
+
if attr is not None:
|
395 |
+
attr.data = attr.data.half()
|
396 |
+
|
397 |
+
model.apply(_convert_weights_to_fp16)
|
398 |
+
|
399 |
+
|
400 |
+
def build_model(state_dict: dict):
|
401 |
+
vit = "visual.proj" in state_dict
|
402 |
+
|
403 |
+
if vit:
|
404 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
405 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
406 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
407 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
408 |
+
image_resolution = vision_patch_size * grid_size
|
409 |
+
else:
|
410 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
411 |
+
vision_layers = tuple(counts)
|
412 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
413 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
414 |
+
vision_patch_size = None
|
415 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
416 |
+
image_resolution = output_width * 32
|
417 |
+
|
418 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
419 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
420 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
421 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
422 |
+
transformer_heads = transformer_width // 64
|
423 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
424 |
+
|
425 |
+
model = CLIP(
|
426 |
+
embed_dim,
|
427 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
428 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
429 |
+
)
|
430 |
+
|
431 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
432 |
+
if key in state_dict:
|
433 |
+
del state_dict[key]
|
434 |
+
|
435 |
+
convert_weights(model)
|
436 |
+
model.load_state_dict(state_dict)
|
437 |
+
return model.eval()
|
clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
configs/phase1/FineCapEval_clipRN50_mle.yml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/FineCapEval.json
|
6 |
+
input_label_h5: none
|
7 |
+
input_fc_dir: data/FineCapEval_clip_RN50_fc
|
8 |
+
input_att_dir: data/FineCapEval_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
|
10 |
+
|
11 |
+
seq_per_img: 5
|
12 |
+
batch_size: 200
|
13 |
+
learning_rate: 0.0005
|
14 |
+
|
15 |
+
checkpoint_path: ./save/clipRN50_mle/clipRN50_mle
|
16 |
+
|
17 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
18 |
+
|
19 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
20 |
+
# N=num_layers
|
21 |
+
# d_model=input_encoding_size
|
22 |
+
# d_ff=rnn_size
|
23 |
+
|
24 |
+
# will be ignored
|
25 |
+
num_layers: 6
|
26 |
+
input_encoding_size: 512
|
27 |
+
rnn_size: 2048
|
28 |
+
|
29 |
+
# Transformer config
|
30 |
+
N_enc: 6
|
31 |
+
N_dec: 6
|
32 |
+
d_model: 512
|
33 |
+
d_ff: 2048
|
34 |
+
num_att_heads: 8
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
|
38 |
+
learning_rate_decay_start: 0
|
39 |
+
scheduled_sampling_start: -1
|
40 |
+
save_checkpoint_every: 3000
|
41 |
+
language_eval: 1
|
42 |
+
val_images_use: 5000
|
43 |
+
max_epochs: 15
|
44 |
+
train_sample_n: 5
|
45 |
+
|
46 |
+
REFORWARD: false
|
47 |
+
|
48 |
+
# _BASE_: transformer.yml
|
49 |
+
reduce_on_plateau: false
|
50 |
+
noamopt: false
|
51 |
+
learning_rate: 0.000005
|
52 |
+
learning_rate_decay_start: -1
|
53 |
+
|
54 |
+
self_critical_after: 15
|
55 |
+
max_epochs: 50
|
56 |
+
|
57 |
+
verbose: false
|
58 |
+
precision: 32
|
59 |
+
|
60 |
+
use_clipscore: false
|
configs/phase1/clipRN50_mle.yml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
# noamopt: false
|
4 |
+
noamopt_warmup: 20000
|
5 |
+
label_smoothing: 0.0
|
6 |
+
input_json: data/cocotalk.json
|
7 |
+
input_label_h5: data/cocotalk_label.h5
|
8 |
+
input_fc_dir: data/cocotalk_clip_RN50_fc
|
9 |
+
input_att_dir: data/cocotalk_clip_RN50_att
|
10 |
+
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
11 |
+
seq_per_img: 5
|
12 |
+
# batch_size: 600
|
13 |
+
batch_size: 200
|
14 |
+
|
15 |
+
learning_rate: 0.0005
|
16 |
+
|
17 |
+
# checkpoint_path: ./save/trans_clip_rn50_sc_pl
|
18 |
+
checkpoint_path: save/clipRN50_mle/clipRN50_mle
|
19 |
+
|
20 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
21 |
+
# N=num_layers
|
22 |
+
# d_model=input_encoding_size
|
23 |
+
# d_ff=rnn_size
|
24 |
+
|
25 |
+
# will be ignored
|
26 |
+
num_layers: 6
|
27 |
+
input_encoding_size: 512
|
28 |
+
rnn_size: 2048
|
29 |
+
|
30 |
+
# Transformer config
|
31 |
+
N_enc: 6
|
32 |
+
N_dec: 6
|
33 |
+
d_model: 512
|
34 |
+
d_ff: 2048
|
35 |
+
num_att_heads: 8
|
36 |
+
dropout: 0.1
|
37 |
+
|
38 |
+
|
39 |
+
learning_rate_decay_start: 0
|
40 |
+
scheduled_sampling_start: -1
|
41 |
+
save_checkpoint_every: 3000
|
42 |
+
language_eval: 1
|
43 |
+
val_images_use: 5000
|
44 |
+
# max_epochs: 15
|
45 |
+
max_epochs: 25
|
46 |
+
train_sample_n: 5
|
47 |
+
|
48 |
+
REFORWARD: false
|
49 |
+
|
50 |
+
|
51 |
+
verbose: false
|
52 |
+
precision: 16
|
configs/phase1/transformer.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_att_dir: data/cocotalk_att
|
8 |
+
seq_per_img: 5
|
9 |
+
batch_size: 10
|
10 |
+
learning_rate: 0.0005
|
11 |
+
|
12 |
+
checkpoint_path: ./save/trans_rn50_sc
|
13 |
+
|
14 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
15 |
+
# N=num_layers
|
16 |
+
# d_model=input_encoding_size
|
17 |
+
# d_ff=rnn_size
|
18 |
+
|
19 |
+
# will be ignored
|
20 |
+
num_layers: 6
|
21 |
+
input_encoding_size: 512
|
22 |
+
rnn_size: 2048
|
23 |
+
|
24 |
+
# Transformer config
|
25 |
+
N_enc: 6
|
26 |
+
N_dec: 6
|
27 |
+
d_model: 512
|
28 |
+
d_ff: 2048
|
29 |
+
num_att_heads: 8
|
30 |
+
dropout: 0.1
|
31 |
+
|
32 |
+
|
33 |
+
learning_rate_decay_start: 0
|
34 |
+
scheduled_sampling_start: -1
|
35 |
+
save_checkpoint_every: 3000
|
36 |
+
language_eval: 1
|
37 |
+
val_images_use: 5000
|
38 |
+
max_epochs: 15
|
39 |
+
train_sample_n: 5
|
40 |
+
|
41 |
+
REFORWARD: false
|
configs/phase2/FineCapEval_clipRN50_cider.yml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/FineCapEval.json
|
6 |
+
input_label_h5: none
|
7 |
+
input_fc_dir: data/FineCapEval_clip_RN50_fc
|
8 |
+
input_att_dir: data/FineCapEval_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
|
10 |
+
|
11 |
+
seq_per_img: 5
|
12 |
+
batch_size: 200
|
13 |
+
learning_rate: 0.0005
|
14 |
+
|
15 |
+
checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
|
16 |
+
|
17 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
18 |
+
|
19 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
20 |
+
# N=num_layers
|
21 |
+
# d_model=input_encoding_size
|
22 |
+
# d_ff=rnn_size
|
23 |
+
|
24 |
+
# will be ignored
|
25 |
+
num_layers: 6
|
26 |
+
input_encoding_size: 512
|
27 |
+
rnn_size: 2048
|
28 |
+
|
29 |
+
# Transformer config
|
30 |
+
N_enc: 6
|
31 |
+
N_dec: 6
|
32 |
+
d_model: 512
|
33 |
+
d_ff: 2048
|
34 |
+
num_att_heads: 8
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
|
38 |
+
learning_rate_decay_start: 0
|
39 |
+
scheduled_sampling_start: -1
|
40 |
+
save_checkpoint_every: 3000
|
41 |
+
language_eval: 1
|
42 |
+
val_images_use: 5000
|
43 |
+
max_epochs: 15
|
44 |
+
train_sample_n: 5
|
45 |
+
|
46 |
+
REFORWARD: false
|
47 |
+
|
48 |
+
# _BASE_: transformer.yml
|
49 |
+
reduce_on_plateau: false
|
50 |
+
noamopt: false
|
51 |
+
learning_rate: 0.000005
|
52 |
+
learning_rate_decay_start: -1
|
53 |
+
|
54 |
+
self_critical_after: 15
|
55 |
+
max_epochs: 50
|
56 |
+
|
57 |
+
verbose: false
|
58 |
+
precision: 32
|
59 |
+
|
60 |
+
# use_clipscore: true
|
61 |
+
use_clipscore: false
|
configs/phase2/FineCapEval_clipRN50_cider_clips.yml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/FineCapEval.json
|
6 |
+
input_label_h5: none
|
7 |
+
input_fc_dir: data/FineCapEval_clip_RN50_fc
|
8 |
+
input_att_dir: data/FineCapEval_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
|
10 |
+
|
11 |
+
seq_per_img: 5
|
12 |
+
batch_size: 200
|
13 |
+
learning_rate: 0.0005
|
14 |
+
|
15 |
+
checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
|
16 |
+
|
17 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
18 |
+
|
19 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
20 |
+
# N=num_layers
|
21 |
+
# d_model=input_encoding_size
|
22 |
+
# d_ff=rnn_size
|
23 |
+
|
24 |
+
# will be ignored
|
25 |
+
num_layers: 6
|
26 |
+
input_encoding_size: 512
|
27 |
+
rnn_size: 2048
|
28 |
+
|
29 |
+
# Transformer config
|
30 |
+
N_enc: 6
|
31 |
+
N_dec: 6
|
32 |
+
d_model: 512
|
33 |
+
d_ff: 2048
|
34 |
+
num_att_heads: 8
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
|
38 |
+
learning_rate_decay_start: 0
|
39 |
+
scheduled_sampling_start: -1
|
40 |
+
save_checkpoint_every: 3000
|
41 |
+
language_eval: 1
|
42 |
+
val_images_use: 5000
|
43 |
+
max_epochs: 15
|
44 |
+
train_sample_n: 5
|
45 |
+
|
46 |
+
REFORWARD: false
|
47 |
+
|
48 |
+
# _BASE_: transformer.yml
|
49 |
+
reduce_on_plateau: false
|
50 |
+
noamopt: false
|
51 |
+
learning_rate: 0.000005
|
52 |
+
learning_rate_decay_start: -1
|
53 |
+
|
54 |
+
self_critical_after: 15
|
55 |
+
max_epochs: 50
|
56 |
+
|
57 |
+
verbose: false
|
58 |
+
precision: 32
|
59 |
+
|
60 |
+
# use_clipscore: true
|
61 |
+
use_clipscore: false
|
62 |
+
clipscore_reward_weight: 2.0
|
63 |
+
clipscore_mode: clip_s
|
64 |
+
|
65 |
+
use_multi_rewards: true
|
configs/phase2/FineCapEval_clipRN50_clips.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/FineCapEval.json
|
6 |
+
input_label_h5: none
|
7 |
+
input_fc_dir: data/FineCapEval_clip_RN50_fc
|
8 |
+
input_att_dir: data/FineCapEval_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
|
10 |
+
seq_per_img: 5
|
11 |
+
batch_size: 160
|
12 |
+
learning_rate: 0.0005
|
13 |
+
|
14 |
+
checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
|
15 |
+
|
16 |
+
use_multi_rewards: false
|
17 |
+
use_grammar: false
|
18 |
+
use_grammar_baseline: false
|
19 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
20 |
+
|
21 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
22 |
+
# N=num_layers
|
23 |
+
# d_model=input_encoding_size
|
24 |
+
# d_ff=rnn_size
|
25 |
+
|
26 |
+
# will be ignored
|
27 |
+
num_layers: 6
|
28 |
+
input_encoding_size: 512
|
29 |
+
rnn_size: 2048
|
30 |
+
|
31 |
+
# Transformer config
|
32 |
+
N_enc: 6
|
33 |
+
N_dec: 6
|
34 |
+
d_model: 512
|
35 |
+
d_ff: 2048
|
36 |
+
num_att_heads: 8
|
37 |
+
dropout: 0.1
|
38 |
+
|
39 |
+
|
40 |
+
learning_rate_decay_start: 0
|
41 |
+
scheduled_sampling_start: -1
|
42 |
+
save_checkpoint_every: 3000
|
43 |
+
language_eval: 0
|
44 |
+
val_images_use: 5000
|
45 |
+
max_epochs: 15
|
46 |
+
train_sample_n: 5
|
47 |
+
|
48 |
+
REFORWARD: false
|
49 |
+
|
50 |
+
# _BASE_: transformer.yml
|
51 |
+
reduce_on_plateau: false
|
52 |
+
noamopt: false
|
53 |
+
learning_rate: 0.000005
|
54 |
+
learning_rate_decay_start: -1
|
55 |
+
|
56 |
+
self_critical_after: 15
|
57 |
+
max_epochs: 50
|
58 |
+
|
59 |
+
verbose: false
|
60 |
+
precision: 32
|
61 |
+
|
62 |
+
# use_clipscore: true
|
63 |
+
use_clipscore: false
|
64 |
+
clipscore_reward_weight: 2.0
|
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/FineCapEval.json
|
6 |
+
input_label_h5: none
|
7 |
+
input_fc_dir: data/FineCapEval_clip_RN50_fc
|
8 |
+
input_att_dir: data/FineCapEval_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
|
10 |
+
seq_per_img: 5
|
11 |
+
batch_size: 160
|
12 |
+
learning_rate: 0.0005
|
13 |
+
|
14 |
+
checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
|
15 |
+
|
16 |
+
use_multi_rewards: true
|
17 |
+
use_grammar: true
|
18 |
+
use_grammar_baseline: true
|
19 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
20 |
+
|
21 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
22 |
+
# N=num_layers
|
23 |
+
# d_model=input_encoding_size
|
24 |
+
# d_ff=rnn_size
|
25 |
+
|
26 |
+
# will be ignored
|
27 |
+
num_layers: 6
|
28 |
+
input_encoding_size: 512
|
29 |
+
rnn_size: 2048
|
30 |
+
|
31 |
+
# Transformer config
|
32 |
+
N_enc: 6
|
33 |
+
N_dec: 6
|
34 |
+
d_model: 512
|
35 |
+
d_ff: 2048
|
36 |
+
num_att_heads: 8
|
37 |
+
dropout: 0.1
|
38 |
+
|
39 |
+
|
40 |
+
learning_rate_decay_start: 0
|
41 |
+
scheduled_sampling_start: -1
|
42 |
+
save_checkpoint_every: 3000
|
43 |
+
language_eval: 0
|
44 |
+
val_images_use: 5000
|
45 |
+
max_epochs: 15
|
46 |
+
train_sample_n: 5
|
47 |
+
|
48 |
+
REFORWARD: false
|
49 |
+
|
50 |
+
# _BASE_: transformer.yml
|
51 |
+
reduce_on_plateau: false
|
52 |
+
noamopt: false
|
53 |
+
learning_rate: 0.000005
|
54 |
+
learning_rate_decay_start: -1
|
55 |
+
|
56 |
+
self_critical_after: 15
|
57 |
+
max_epochs: 50
|
58 |
+
|
59 |
+
verbose: false
|
60 |
+
precision: 32
|
61 |
+
|
62 |
+
# use_clipscore: true
|
63 |
+
use_clipscore: false
|
64 |
+
clipscore_reward_weight: 2.0
|
configs/phase2/clipRN50_cider.yml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_fc_dir: data/cocotalk_clip_RN50_fc
|
8 |
+
input_att_dir: data/cocotalk_clip_RN50_att
|
9 |
+
# used only for evaluation
|
10 |
+
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
11 |
+
|
12 |
+
seq_per_img: 5
|
13 |
+
batch_size: 200
|
14 |
+
learning_rate: 0.0005
|
15 |
+
|
16 |
+
# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
|
17 |
+
checkpoint_path: save/clipRN50_cider/clipRN50_cider
|
18 |
+
|
19 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
20 |
+
# N=num_layers
|
21 |
+
# d_model=input_encoding_size
|
22 |
+
# d_ff=rnn_size
|
23 |
+
|
24 |
+
# will be ignored
|
25 |
+
num_layers: 6
|
26 |
+
input_encoding_size: 512
|
27 |
+
rnn_size: 2048
|
28 |
+
|
29 |
+
# Transformer config
|
30 |
+
N_enc: 6
|
31 |
+
N_dec: 6
|
32 |
+
d_model: 512
|
33 |
+
d_ff: 2048
|
34 |
+
num_att_heads: 8
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
|
38 |
+
learning_rate_decay_start: 0
|
39 |
+
scheduled_sampling_start: -1
|
40 |
+
save_checkpoint_every: 3000
|
41 |
+
language_eval: 1
|
42 |
+
val_images_use: 5000
|
43 |
+
max_epochs: 15
|
44 |
+
train_sample_n: 5
|
45 |
+
|
46 |
+
REFORWARD: false
|
47 |
+
|
48 |
+
# _BASE_: transformer.yml
|
49 |
+
reduce_on_plateau: false
|
50 |
+
noamopt: false
|
51 |
+
learning_rate: 0.000005
|
52 |
+
learning_rate_decay_start: -1
|
53 |
+
|
54 |
+
self_critical_after: 15
|
55 |
+
max_epochs: 40
|
56 |
+
|
57 |
+
verbose: false
|
58 |
+
precision: 32
|
configs/phase2/clipRN50_cider_clips.yml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_fc_dir: data/cocotalk_clip_RN50_fc
|
8 |
+
input_att_dir: data/cocotalk_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
10 |
+
seq_per_img: 5
|
11 |
+
batch_size: 160
|
12 |
+
learning_rate: 0.0005
|
13 |
+
|
14 |
+
checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
|
15 |
+
|
16 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
17 |
+
# N=num_layers
|
18 |
+
# d_model=input_encoding_size
|
19 |
+
# d_ff=rnn_size
|
20 |
+
|
21 |
+
# will be ignored
|
22 |
+
num_layers: 6
|
23 |
+
input_encoding_size: 512
|
24 |
+
rnn_size: 2048
|
25 |
+
|
26 |
+
# Transformer config
|
27 |
+
N_enc: 6
|
28 |
+
N_dec: 6
|
29 |
+
d_model: 512
|
30 |
+
d_ff: 2048
|
31 |
+
num_att_heads: 8
|
32 |
+
dropout: 0.1
|
33 |
+
|
34 |
+
|
35 |
+
learning_rate_decay_start: 0
|
36 |
+
scheduled_sampling_start: -1
|
37 |
+
save_checkpoint_every: 3000
|
38 |
+
language_eval: 1
|
39 |
+
val_images_use: 5000
|
40 |
+
max_epochs: 15
|
41 |
+
train_sample_n: 5
|
42 |
+
|
43 |
+
REFORWARD: false
|
44 |
+
|
45 |
+
# _BASE_: transformer.yml
|
46 |
+
reduce_on_plateau: false
|
47 |
+
noamopt: false
|
48 |
+
learning_rate: 0.000005
|
49 |
+
learning_rate_decay_start: -1
|
50 |
+
|
51 |
+
self_critical_after: 15
|
52 |
+
max_epochs: 40
|
53 |
+
|
54 |
+
verbose: false
|
55 |
+
precision: 32
|
56 |
+
|
57 |
+
use_clipscore: true
|
58 |
+
clipscore_reward_weight: 2.0
|
59 |
+
clipscore_mode: clip_s
|
60 |
+
|
61 |
+
use_multi_rewards: true
|
configs/phase2/clipRN50_clips.yml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_fc_dir: data/cocotalk_clip_RN50_fc
|
8 |
+
input_att_dir: data/cocotalk_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
10 |
+
seq_per_img: 5
|
11 |
+
batch_size: 160
|
12 |
+
learning_rate: 0.0005
|
13 |
+
|
14 |
+
checkpoint_path: save/clipRN50_clips/clipRN50_clips
|
15 |
+
|
16 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
17 |
+
# N=num_layers
|
18 |
+
# d_model=input_encoding_size
|
19 |
+
# d_ff=rnn_size
|
20 |
+
|
21 |
+
# will be ignored
|
22 |
+
num_layers: 6
|
23 |
+
input_encoding_size: 512
|
24 |
+
rnn_size: 2048
|
25 |
+
|
26 |
+
# Transformer config
|
27 |
+
N_enc: 6
|
28 |
+
N_dec: 6
|
29 |
+
d_model: 512
|
30 |
+
d_ff: 2048
|
31 |
+
num_att_heads: 8
|
32 |
+
dropout: 0.1
|
33 |
+
|
34 |
+
|
35 |
+
learning_rate_decay_start: 0
|
36 |
+
scheduled_sampling_start: -1
|
37 |
+
save_checkpoint_every: 3000
|
38 |
+
language_eval: 1
|
39 |
+
val_images_use: 5000
|
40 |
+
max_epochs: 15
|
41 |
+
train_sample_n: 5
|
42 |
+
|
43 |
+
REFORWARD: false
|
44 |
+
|
45 |
+
# _BASE_: transformer.yml
|
46 |
+
reduce_on_plateau: false
|
47 |
+
noamopt: false
|
48 |
+
learning_rate: 0.000005
|
49 |
+
learning_rate_decay_start: -1
|
50 |
+
|
51 |
+
self_critical_after: 15
|
52 |
+
max_epochs: 40
|
53 |
+
|
54 |
+
verbose: false
|
55 |
+
precision: 32
|
56 |
+
|
57 |
+
use_clipscore: true
|
58 |
+
clipscore_reward_weight: 2.0
|
configs/phase2/clipRN50_clips_grammar.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_fc_dir: data/cocotalk_clip_RN50_fc
|
8 |
+
input_att_dir: data/cocotalk_clip_RN50_att
|
9 |
+
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
|
10 |
+
seq_per_img: 5
|
11 |
+
batch_size: 160
|
12 |
+
learning_rate: 0.0005
|
13 |
+
|
14 |
+
checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
|
15 |
+
|
16 |
+
use_multi_rewards: true
|
17 |
+
use_grammar: true
|
18 |
+
use_grammar_baseline: true
|
19 |
+
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
|
20 |
+
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
|
21 |
+
|
22 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
23 |
+
# N=num_layers
|
24 |
+
# d_model=input_encoding_size
|
25 |
+
# d_ff=rnn_size
|
26 |
+
|
27 |
+
# will be ignored
|
28 |
+
num_layers: 6
|
29 |
+
input_encoding_size: 512
|
30 |
+
rnn_size: 2048
|
31 |
+
|
32 |
+
# Transformer config
|
33 |
+
N_enc: 6
|
34 |
+
N_dec: 6
|
35 |
+
d_model: 512
|
36 |
+
d_ff: 2048
|
37 |
+
num_att_heads: 8
|
38 |
+
dropout: 0.1
|
39 |
+
|
40 |
+
|
41 |
+
learning_rate_decay_start: 0
|
42 |
+
scheduled_sampling_start: -1
|
43 |
+
save_checkpoint_every: 3000
|
44 |
+
language_eval: 1
|
45 |
+
val_images_use: 5000
|
46 |
+
max_epochs: 15
|
47 |
+
train_sample_n: 5
|
48 |
+
|
49 |
+
REFORWARD: false
|
50 |
+
|
51 |
+
# _BASE_: transformer.yml
|
52 |
+
reduce_on_plateau: false
|
53 |
+
noamopt: false
|
54 |
+
learning_rate: 0.000005
|
55 |
+
learning_rate_decay_start: -1
|
56 |
+
|
57 |
+
self_critical_after: 15
|
58 |
+
max_epochs: 40
|
59 |
+
|
60 |
+
verbose: false
|
61 |
+
precision: 32
|
62 |
+
|
63 |
+
use_clipscore: true
|
64 |
+
clipscore_reward_weight: 2.0
|
configs/phase2/transformer.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
caption_model: transformer
|
2 |
+
noamopt: true
|
3 |
+
noamopt_warmup: 20000
|
4 |
+
label_smoothing: 0.0
|
5 |
+
input_json: data/cocotalk.json
|
6 |
+
input_label_h5: data/cocotalk_label.h5
|
7 |
+
input_att_dir: data/cocotalk_att
|
8 |
+
seq_per_img: 5
|
9 |
+
batch_size: 10
|
10 |
+
learning_rate: 0.0005
|
11 |
+
|
12 |
+
checkpoint_path: ./save/trans_rn50_sc
|
13 |
+
|
14 |
+
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
|
15 |
+
# N=num_layers
|
16 |
+
# d_model=input_encoding_size
|
17 |
+
# d_ff=rnn_size
|
18 |
+
|
19 |
+
# will be ignored
|
20 |
+
num_layers: 6
|
21 |
+
input_encoding_size: 512
|
22 |
+
rnn_size: 2048
|
23 |
+
|
24 |
+
# Transformer config
|
25 |
+
N_enc: 6
|
26 |
+
N_dec: 6
|
27 |
+
d_model: 512
|
28 |
+
d_ff: 2048
|
29 |
+
num_att_heads: 8
|
30 |
+
dropout: 0.1
|
31 |
+
|
32 |
+
|
33 |
+
learning_rate_decay_start: 0
|
34 |
+
scheduled_sampling_start: -1
|
35 |
+
save_checkpoint_every: 3000
|
36 |
+
language_eval: 1
|
37 |
+
val_images_use: 5000
|
38 |
+
max_epochs: 15
|
39 |
+
train_sample_n: 5
|
40 |
+
|
41 |
+
REFORWARD: false
|
data/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
directory to store preprocessed files
|