yashonwu committed on
Commit
9bf9e42
1 Parent(s): ed7b5bc

add captioning

captioning/.DS_Store ADDED
Binary file (8.2 kB).
 
captioning/__init__.py ADDED
File without changes
captioning/captioner.py ADDED
@@ -0,0 +1,296 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json, math
6
+ import numpy as np
7
+
8
+ import os, sys
9
+ from six.moves import cPickle
10
+
11
+ from sys import path
12
+
13
+ sys.path.insert(0, os.getcwd())
14
+ sys.path.insert(0, 'captioning/')
15
+ # print('relative captioning is called')
16
+
17
+ import captioning.utils.opts as opts
18
+ import captioning.models as models
19
+ from captioning.data.dataloader import *
20
+ from captioning.data.dataloaderraw import *
21
+
22
+ import argparse
23
+ import captioning.utils.misc as utils
24
+ import torch
25
+
26
+ import skimage.io
27
+ from torch.autograd import Variable
28
+ from torchvision import transforms as trn
29
+
30
+ preprocess = trn.Compose([
31
+ # trn.ToTensor(),
32
+ trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33
+ ])
34
+
35
+ from captioning.utils.resnet_utils import myResnet
36
+ from captioning.utils.resnet_utils import ResNetBatch
37
+ import captioning.utils.resnet as resnet
38
+
39
+ import wget
40
+ import tempfile
41
+
42
+ class Options:  # simple container for decoding options; avoids shadowing the built-in `object`
43
+ def __init__(self):
44
+ self.input_fc_dir = ''
45
+ self.input_json = ''
46
+ self.batch_size = 0  # 0 so the batch size stored in the infos file is picked up below
47
+ self.id = ''
48
+ self.sample_max = 1
49
+ self.cnn_model = 'resnet101'
50
+ self.model = ''
51
+ self.language_eval = 0
52
+ self.beam_size = 1
53
+ self.temperature = 1.0
54
+ return
55
+
56
+
57
+ class Captioner():
58
+
59
+ def __init__(self, is_relative=True, model_path=None, image_feat_params=None, data_type=None, load_resnet=True, diff_feat=None):
60
+ opt = Options()
61
+
62
+ if image_feat_params is None:
63
+ image_feat_params = {}
64
+ image_feat_params['model'] = 'resnet101'
65
+ image_feat_params['model_root'] = ''
66
+ image_feat_params['att_size'] = 7
67
+
68
+ # inputs specific to shoe dataset
69
+ infos_path = os.path.join(model_path, 'infos_best.pkl')
70
+ model_path = os.path.join(model_path, 'model_best.pth')
71
+
72
+ opt.infos_path = infos_path
73
+ opt.model_path = model_path
74
+ opt.beam_size = 1
75
+ opt.load_resnet = load_resnet
76
+
77
+ # load pre-trained model, adjusting if URL
78
+ if opt.infos_path.startswith("http:") or opt.infos_path.startswith("https:"):
79
+ # create a folder to store the checkpoints for downloading
80
+ if not os.path.exists('./checkpoints_usersim'):
81
+ os.mkdir('./checkpoints_usersim')
82
+
83
+ checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
84
+ if not os.path.exists(checkpoint_path):
85
+ os.mkdir(checkpoint_path)
86
+
87
+ # set the location for infos
88
+ infos_loc = os.path.join(checkpoint_path, 'infos_best.pkl')
89
+
90
+ if not os.path.exists(infos_loc):
91
+ try:
92
+ wget.download(opt.infos_path, infos_loc)
93
+ except Exception as err:
94
+ print(f"[{err}]")
95
+ else:
96
+ infos_loc = infos_path
97
+
98
+ if opt.model_path.startswith("http:") or opt.model_path.startswith("https:"):
99
+ # create a folder to store the checkpoints for downloading
100
+ if not os.path.exists('./checkpoints_usersim'):
101
+ os.mkdir('./checkpoints_usersim')
102
+
103
+ checkpoint_path = os.path.join('./checkpoints_usersim', data_type)
104
+ if not os.path.exists(checkpoint_path):
105
+ os.mkdir(checkpoint_path)
106
+
107
+ # set the location for models
108
+ model_loc = os.path.join(checkpoint_path, 'model_best.pth')
109
+
110
+ if not os.path.exists(model_loc):
111
+ try:
112
+ wget.download(opt.model_path, model_loc)
113
+ except Exception as err:
114
+ print(f"[{err}]")
115
+ opt.model = model_loc
116
+ else:
117
+ opt.model = model_path
118
+
119
+ if os.path.exists(infos_loc):
120
+ # load existing infos
121
+ with open(infos_loc, 'rb') as f:
122
+ infos = cPickle.load(f)
123
+
124
+ self.caption_model = infos["opt"].caption_model
125
+
126
+ # override and collect parameters
127
+ if len(opt.input_fc_dir) == 0:
128
+ opt.input_fc_dir = infos['opt'].input_fc_dir
129
+ opt.input_att_dir = infos['opt'].input_att_dir
130
+ opt.input_label_h5 = infos['opt'].input_label_h5
131
+ if len(opt.input_json) == 0:
132
+ opt.input_json = infos['opt'].input_json
133
+ if opt.batch_size == 0:
134
+ opt.batch_size = infos['opt'].batch_size
135
+ if len(opt.id) == 0:
136
+ opt.id = infos['opt'].id
137
+ ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "model"]
138
+ for k in vars(infos['opt']).keys():
139
+ if k not in ignore:
140
+ if k in vars(opt):
141
+ assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
142
+ else:
143
+ vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
144
+
145
+ vocab = infos['vocab'] # ix -> word mapping
146
+
147
+ # print('opt:', opt)
148
+
149
+ # Setup the model
150
+ opt.vocab = vocab
151
+ model = models.setup(opt)
152
+ del opt.vocab
153
+ if torch.cuda.is_available():
154
+ model.load_state_dict(torch.load(opt.model))
155
+ model.cuda()
156
+ else:
157
+ model.load_state_dict(torch.load(opt.model, map_location={'cuda:0': 'cpu'}))
158
+
159
+ model.eval()
160
+
161
+ self.is_relative = is_relative
162
+ self.model = model
163
+ self.vocab = vocab
164
+ self.opt = vars(opt)
165
+
166
+ # Load ResNet for processing images
167
+ if opt.load_resnet:
168
+ if image_feat_params['model_root']=='':
169
+ net = getattr(resnet, image_feat_params['model'])(pretrained=True)
170
+ else:
171
+ net = getattr(resnet, image_feat_params['model'])()
172
+ net.load_state_dict(
173
+ torch.load(os.path.join(image_feat_params['model_root'], image_feat_params['model'] + '.pth')))
174
+ my_resnet = myResnet(net)
175
+ if torch.cuda.is_available():
176
+ my_resnet.cuda()
177
+ my_resnet.eval()
178
+
179
+ my_resnet_batch = ResNetBatch(net)
180
+ if torch.cuda.is_available():
181
+ my_resnet_batch.cuda()
182
+
183
+ self.my_resnet_batch = my_resnet_batch
184
+ self.my_resnet = my_resnet
185
+ self.att_size = image_feat_params['att_size']
186
+
187
+ # Control the input features of the model
188
+ if diff_feat is None:
189
+ if self.caption_model == "show_attend_tell":
190
+ self.diff_feat = True
191
+ else:
192
+ self.diff_feat = False
193
+ else:
194
+ self.diff_feat = diff_feat
195
+
196
+ def gen_caption_from_feat(self, feat_target, feat_reference=None):
197
+ if self.is_relative and feat_reference is None:
198
+ return None, None
199
+
200
+ if not self.is_relative and feat_reference is not None:
201
+ return None, None
202
+
203
+ if self.is_relative:
204
+ if self.diff_feat:
205
+ fc_feat = torch.cat((feat_target[0], feat_target[0] - feat_reference[0]), dim=-1)
206
+ att_feat = torch.cat((feat_target[1], feat_target[1] - feat_reference[1]), dim=-1)
207
+ else:
208
+ fc_feat = torch.cat((feat_target[0], feat_reference[0]), dim=-1)
209
+ att_feat = torch.cat((feat_target[1], feat_reference[1]), dim=-1)
210
+ else:
211
+ fc_feat = feat_target[0]
212
+ att_feat = feat_target[1]
213
+
214
+ # Reshape to B x K x C (128,14,14,4096) --> (128,196,4096)
215
+ att_feat = att_feat.view(att_feat.shape[0], att_feat.shape[1] * att_feat.shape[2], att_feat.shape[-1])
216
+
217
+ att_masks = np.zeros(att_feat.shape[:2], dtype='float32')
218
+ for i in range(len(att_feat)):
219
+ att_masks[i, :att_feat[i].shape[0]] = 1
220
+ # set att_masks to None if attention features have same length
221
+ if att_masks.sum() == att_masks.size:
222
+ att_masks = None
223
+
224
+ if self.caption_model == 'show_attend_tell':
225
+ seq, _ = self.model.sample(fc_feat, att_feat, self.opt)
226
+ else:
227
+ seq, _ = self.model(fc_feat, att_feat, att_masks=att_masks, opt=self.opt, mode='sample')
228
+ sents = utils.decode_sequence(self.vocab, seq)
229
+
230
+ return seq, sents
231
+
232
+ def get_vocab_size(self):
233
+ return len(self.vocab)
234
+
235
+ def get_img_feat(self, img_name):
236
+ # load the image
237
+ I = skimage.io.imread(img_name)
238
+
239
+ if len(I.shape) == 2:
240
+ I = I[:, :, np.newaxis]
241
+ I = np.concatenate((I, I, I), axis=2)
242
+
243
+ I = I.astype('float32') / 255.0
244
+ I = torch.from_numpy(I.transpose([2, 0, 1]))
245
+ if torch.cuda.is_available(): I = I.cuda()
246
+ # I = Variable(preprocess(I), volatile=True)
247
+ with torch.no_grad():
248
+ I = preprocess(I)
249
+ fc, att = self.my_resnet(I, self.att_size)
250
+
251
+ return fc, att
252
+
253
+ def get_img_feat_batch(self, img_names, batchsize=32):
254
+ if not isinstance(img_names, list):
255
+ img_names = [img_names]
256
+
257
+ num_images = len(img_names)
258
+ num_batches = math.ceil(num_images / batchsize)  # avoid np.float, which is removed in recent NumPy releases
259
+
260
+ feature_fc = []
261
+ feature_att = []
262
+
263
+ for id in range(num_batches):
264
+ startInd = id * batchsize
265
+ endInd = min((id + 1) * batchsize, num_images)
266
+
267
+ img_names_current_batch = img_names[startInd:endInd]
268
+ I_current_batch = []
269
+
270
+ for img_name in img_names_current_batch:
271
+ I = skimage.io.imread(img_name)
272
+
273
+ if len(I.shape) == 2:
274
+ I = I[:, :, np.newaxis]
275
+ I = np.concatenate((I, I, I), axis=2)
276
+
277
+ I = I.astype('float32') / 255.0
278
+ I = torch.from_numpy(I.transpose([2, 0, 1]))
279
+ I_current_batch.append(preprocess(I))
280
+
281
+ I_current_batch = torch.stack(I_current_batch, dim=0)
282
+ if torch.cuda.is_available(): I_current_batch = I_current_batch.cuda()
283
+ # I_current_batch = Variable(I_current_batch, volatile=True)
284
+ with torch.no_grad():
285
+ fc, att = self.my_resnet_batch(I_current_batch, self.att_size)
286
+
287
+ feature_fc.append(fc)
288
+ feature_att.append(att)
289
+
290
+ feature_fc = torch.cat(feature_fc, dim=0)
291
+ feature_att = torch.cat(feature_att, dim=0)
292
+
293
+ return feature_fc, feature_att
294
+
295
+
296
+
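For orientation, a minimal usage sketch of the Captioner class added above. The checkpoint directory, data_type and image paths are hypothetical placeholders; the constructor only requires that model_path contain infos_best.pkl and model_best.pth (or point to URLs for them).

# Hedged usage sketch; all paths below are assumptions, adjust to your checkpoints and images.
from captioning.captioner import Captioner

captioner = Captioner(
    is_relative=True,                 # describe the target image relative to a reference image
    model_path='checkpoints/shoes',   # directory holding infos_best.pkl and model_best.pth
    data_type='shoes',                # only used to name the download cache folder for URL checkpoints
    load_resnet=True,                 # needed for get_img_feat / get_img_feat_batch
)

# Batched feature extraction keeps the (B, ...) shapes that gen_caption_from_feat expects.
feat_target = captioner.get_img_feat_batch(['images/target.jpg'])
feat_reference = captioner.get_img_feat_batch(['images/reference.jpg'])

# In relative mode both feature tuples are required; otherwise (None, None) is returned.
seq, sents = captioner.gen_caption_from_feat(feat_target, feat_reference)
print(sents[0])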
captioning/data/.DS_Store ADDED
Binary file (6.15 kB).
 
captioning/data/__init__.py ADDED
File without changes
captioning/data/dataloader.py ADDED
@@ -0,0 +1,439 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.utils.data as data
17
+
18
+ import multiprocessing
19
+ import six
20
+
21
+
22
+ class HybridLoader:
23
+ """
24
+ If db_path is a directory, use normal file loading
25
+ If lmdb, then load from lmdb
26
+ The loading method depends on the extension.
27
+
28
+ in_memory: if in_memory is True, we save all the features in memory
29
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
30
+ Should be useful for lmdb or h5.
31
+ (Copied this idea from vilbert)
32
+ """
33
+ def __init__(self, db_path, ext, in_memory=False):
34
+ self.db_path = db_path
35
+ self.ext = ext
36
+ if self.ext == '.npy':
37
+ self.loader = lambda x: np.load(six.BytesIO(x))
38
+ else:
39
+ def load_npz(x):
40
+ x = np.load(six.BytesIO(x))
41
+ return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
42
+ self.loader = load_npz
43
+ if db_path.endswith('.lmdb'):
44
+ self.db_type = 'lmdb'
45
+ self.lmdb = lmdbdict(db_path, unsafe=True)
46
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
47
+ self.lmdb._value_loads = LOADS_FUNC['identity']
48
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
49
+ self.db_type = 'pth'
50
+ self.feat_file = torch.load(db_path)
51
+ self.loader = lambda x: x
52
+ print('HybridLoader: ext is ignored')
53
+ elif db_path.endswith('h5'):
54
+ self.db_type = 'h5'
55
+ self.loader = lambda x: np.array(x).astype('float32')
56
+ else:
57
+ self.db_type = 'dir'
58
+
59
+ self.in_memory = in_memory
60
+ if self.in_memory:
61
+ self.features = {}
62
+
63
+ def get(self, key):
64
+
65
+ if self.in_memory and key in self.features:
66
+ # We save f_input because we want to save the
67
+ # compressed bytes to save memory
68
+ f_input = self.features[key]
69
+ elif self.db_type == 'lmdb':
70
+ f_input = self.lmdb[key]
71
+ elif self.db_type == 'pth':
72
+ f_input = self.feat_file[key]
73
+ elif self.db_type == 'h5':
74
+ f_input = h5py.File(self.db_path, 'r')[key]
75
+ else:
76
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
77
+
78
+ if self.in_memory and key not in self.features:
79
+ self.features[key] = f_input
80
+
81
+ # load image
82
+ feat = self.loader(f_input)
83
+
84
+ return feat
85
+
86
+ class Dataset(data.Dataset):
87
+
88
+ def get_vocab_size(self):
89
+ return self.vocab_size
90
+
91
+ def get_vocab(self):
92
+ return self.ix_to_word
93
+
94
+ def get_seq_length(self):
95
+ return self.seq_length
96
+
97
+ def __init__(self, opt):
98
+ self.opt = opt
99
+ self.seq_per_img = opt.seq_per_img
100
+
101
+ # feature related options
102
+ self.use_fc = getattr(opt, 'use_fc', True)
103
+ self.use_att = getattr(opt, 'use_att', True)
104
+ self.use_box = getattr(opt, 'use_box', 0)
105
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
106
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
107
+
108
+ # load the json file which contains additional information about the dataset
109
+ print('DataLoader loading json file: ', opt.input_json)
110
+ self.info = json.load(open(self.opt.input_json))
111
+ if 'ix_to_word' in self.info:
112
+ self.ix_to_word = self.info['ix_to_word']
113
+ self.vocab_size = len(self.ix_to_word)
114
+ print('vocab size is ', self.vocab_size)
115
+
116
+ # open the hdf5 file
117
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
118
+ """
119
+ Setting input_label_h5 to none is used when only doing generation.
120
+ For example, when you need to test on coco test set.
121
+ """
122
+ if self.opt.input_label_h5 != 'none':
123
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
124
+ # load in the sequence data
125
+ seq_size = self.h5_label_file['labels'].shape
126
+ self.label = self.h5_label_file['labels'][:]
127
+ self.seq_length = seq_size[1]
128
+ print('max sequence length in data is', self.seq_length)
129
+ # load the pointers in full to RAM (should be small enough)
130
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
131
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
132
+ else:
133
+ self.seq_length = 1
134
+
135
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
136
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
137
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
138
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
139
+
140
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
141
+ print('read %d image features' %(self.num_images))
142
+
143
+ # separate out indexes for each of the provided splits
144
+ self.split_ix = {'train': [], 'val': [], 'test': []}
145
+ for ix in range(len(self.info['images'])):
146
+ img = self.info['images'][ix]
147
+ if not 'split' in img:
148
+ self.split_ix['train'].append(ix)
149
+ self.split_ix['val'].append(ix)
150
+ self.split_ix['test'].append(ix)
151
+ elif img['split'] == 'train':
152
+ self.split_ix['train'].append(ix)
153
+ elif img['split'] == 'val':
154
+ self.split_ix['val'].append(ix)
155
+ elif img['split'] == 'test':
156
+ self.split_ix['test'].append(ix)
157
+ elif opt.train_only == 0: # restval
158
+ self.split_ix['train'].append(ix)
159
+
160
+ print('assigned %d images to split train' %len(self.split_ix['train']))
161
+ print('assigned %d images to split val' %len(self.split_ix['val']))
162
+ print('assigned %d images to split test' %len(self.split_ix['test']))
163
+
164
+ def get_captions(self, ix, seq_per_img):
165
+ # fetch the sequence labels
166
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
167
+ ix2 = self.label_end_ix[ix] - 1
168
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
169
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
170
+
171
+ random.seed(42)
172
+ torch.manual_seed(42)
173
+ if torch.cuda.is_available():
174
+ torch.cuda.manual_seed(42)
175
+
176
+ if ncap < seq_per_img:
177
+ # we need to subsample (with replacement)
178
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
179
+ for q in range(seq_per_img):
180
+ ixl = random.randint(ix1,ix2)
181
+ seq[q, :] = self.label[ixl, :self.seq_length]
182
+ else:
183
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
184
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
185
+
186
+ return seq
187
+
188
+ def collate_func(self, batch, split):
189
+ seq_per_img = self.seq_per_img
190
+
191
+ fc_batch = []
192
+ att_batch = []
193
+ label_batch = []
194
+
195
+ wrapped = False
196
+
197
+ infos = []
198
+ gts = []
199
+
200
+ for sample in batch:
201
+ # fetch image
202
+ tmp_fc, tmp_att, tmp_seq, \
203
+ ix, it_pos_now, tmp_wrapped = sample
204
+ if tmp_wrapped:
205
+ wrapped = True
206
+
207
+ fc_batch.append(tmp_fc)
208
+ att_batch.append(tmp_att)
209
+
210
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
211
+ if hasattr(self, 'h5_label_file'):
212
+ # if there is ground truth
213
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
214
+ label_batch.append(tmp_label)
215
+
216
+ # Used for reward evaluation
217
+ if hasattr(self, 'h5_label_file'):
218
+ # if there is ground truth
219
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
220
+ else:
221
+ gts.append([])
222
+
223
+ # record associated info as well
224
+ info_dict = {}
225
+ info_dict['ix'] = ix
226
+ info_dict['id'] = self.info['images'][ix]['id']
227
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
228
+ infos.append(info_dict)
229
+
230
+ # #sort by att_feat length
231
+ # fc_batch, att_batch, label_batch, gts, infos = \
232
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
233
+ fc_batch, att_batch, label_batch, gts, infos = \
234
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
235
+
236
+ data = {}
237
+ data['fc_feats'] = np.stack(fc_batch)
238
+ # merge att_feats
239
+ max_att_len = max([_.shape[0] for _ in att_batch])
240
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
241
+
242
+ for i in range(len(att_batch)):
243
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
244
+
245
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
246
+ for i in range(len(att_batch)):
247
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
248
+ # set att_masks to None if attention features have same length
249
+ if data['att_masks'].sum() == data['att_masks'].size:
250
+ data['att_masks'] = None
251
+
252
+ data['labels'] = np.vstack(label_batch)
253
+ # generate mask
254
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
255
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
256
+ for ix, row in enumerate(mask_batch):
257
+ row[:nonzeros[ix]] = 1
258
+ data['masks'] = mask_batch
259
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
260
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
261
+
262
+ data['gts'] = gts # all ground-truth captions of each image
263
+ data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
264
+ 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
265
+ data['infos'] = infos
266
+
267
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
268
+
269
+ return data
270
+
271
+ def __getitem__(self, index):
272
+ """This function returns a tuple that is further passed to collate_fn
273
+ """
274
+ ix, it_pos_now, wrapped = index #self.split_ix[index]
275
+ if self.use_att:
276
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
277
+ # shape: (14,14,4096)
278
+
279
+ # Reshape to K x C
280
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
281
+ # shape:(196,4096)
282
+
283
+ if self.norm_att_feat:
284
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
285
+ if self.use_box:
286
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
287
+ # divided by image width and height
288
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
289
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
290
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
291
+ if self.norm_box_feat:
292
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
293
+ att_feat = np.hstack([att_feat, box_feat])
294
+ # sort the features by the size of boxes
295
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
296
+ else:
297
+ att_feat = np.zeros((0,0), dtype='float32')
298
+ if self.use_fc:
299
+ try:
300
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
301
+ except:
302
+ # Use average of attention when there is no fc provided (For bottomup feature)
303
+ fc_feat = att_feat.mean(0)
304
+ else:
305
+ fc_feat = np.zeros((0), dtype='float32')
306
+ if hasattr(self, 'h5_label_file'):
307
+ seq = self.get_captions(ix, self.seq_per_img)
308
+ else:
309
+ seq = None
310
+ return (fc_feat,
311
+ att_feat, seq,
312
+ ix, it_pos_now, wrapped)
313
+
314
+ def __len__(self):
315
+ return len(self.info['images'])
316
+
317
+ class DataLoader:
318
+ def __init__(self, opt):
319
+ self.opt = opt
320
+ self.batch_size = self.opt.batch_size
321
+ self.dataset = Dataset(opt)
322
+
323
+ # Initialize loaders and iters
324
+ self.loaders, self.iters = {}, {}
325
+ for split in ['train', 'val', 'test']:
326
+ if split == 'train':
327
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
328
+ else:
329
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
330
+ self.loaders[split] = data.DataLoader(dataset=self.dataset,
331
+ batch_size=self.batch_size,
332
+ sampler=sampler,
333
+ pin_memory=True,
334
+ num_workers=4, # 4 is usually enough
335
+ collate_fn=partial(self.dataset.collate_func, split=split),
336
+ drop_last=False)
337
+ self.iters[split] = iter(self.loaders[split])
338
+
339
+ def get_batch(self, split):
340
+ try:
341
+ data = next(self.iters[split])
342
+ except StopIteration:
343
+ self.iters[split] = iter(self.loaders[split])
344
+ data = next(self.iters[split])
345
+ return data
346
+
347
+ def reset_iterator(self, split):
348
+ self.loaders[split].sampler._reset_iter()
349
+ self.iters[split] = iter(self.loaders[split])
350
+
351
+ def get_vocab_size(self):
352
+ return self.dataset.get_vocab_size()
353
+
354
+ @property
355
+ def vocab_size(self):
356
+ return self.get_vocab_size()
357
+
358
+ def get_vocab(self):
359
+ return self.dataset.get_vocab()
360
+
361
+ def get_seq_length(self):
362
+ return self.dataset.get_seq_length()
363
+
364
+ @property
365
+ def seq_length(self):
366
+ return self.get_seq_length()
367
+
368
+ def state_dict(self):
369
+ def get_prefetch_num(split):
370
+ if self.loaders[split].num_workers > 0:
371
+ return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
372
+ else:
373
+ return 0
374
+ return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
375
+ for split, loader in self.loaders.items()}
376
+
377
+ def load_state_dict(self, state_dict=None):
378
+ if state_dict is None:
379
+ return
380
+ for split in self.loaders.keys():
381
+ self.loaders[split].sampler.load_state_dict(state_dict[split])
382
+
383
+
384
+ class MySampler(data.sampler.Sampler):
385
+ def __init__(self, index_list, shuffle, wrap):
386
+ self.index_list = index_list
387
+ self.shuffle = shuffle
388
+ self.wrap = wrap
389
+ # if wrap is True, StopIteration is never raised
390
+ # wrap=True is used during training, wrap=False during testing.
391
+ self._reset_iter()
392
+
393
+ def __iter__(self):
394
+ return self
395
+
396
+ def __next__(self):
397
+ wrapped = False
398
+ if self.iter_counter == len(self._index_list):
399
+ self._reset_iter()
400
+ if self.wrap:
401
+ wrapped = True
402
+ else:
403
+ raise StopIteration()
404
+ if len(self._index_list) == 0: # overflow when 0 samples
405
+ return None
406
+ elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
407
+ self.iter_counter += 1
408
+ return elem
409
+
410
+ def next(self):
411
+ return self.__next__()
412
+
413
+ def _reset_iter(self):
414
+ np.random.seed(42)
415
+ if self.shuffle:
416
+ rand_perm = npr.permutation(len(self.index_list))
417
+ self._index_list = [self.index_list[_] for _ in rand_perm]
418
+ else:
419
+ self._index_list = self.index_list
420
+
421
+ self.iter_counter = 0
422
+
423
+ def __len__(self):
424
+ return len(self.index_list)
425
+
426
+ def load_state_dict(self, state_dict=None):
427
+ if state_dict is None:
428
+ return
429
+ self._index_list = state_dict['index_list']
430
+ self.iter_counter = state_dict['iter_counter']
431
+
432
+ def state_dict(self, prefetched_num=None):
433
+ prefetched_num = prefetched_num or 0
434
+ return {
435
+ 'index_list': self._index_list,
436
+ 'iter_counter': self.iter_counter - prefetched_num
437
+ }
438
+
439
+
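A small sketch of how the HybridLoader defined above resolves features. It picks its backend from the db_path suffix (.lmdb, .pth, h5, otherwise a plain directory of .npy/.npz files); the directory names, image id and shapes below are hypothetical.

# Hedged sketch; feature directories and the id '12345' are assumptions.
from captioning.data.dataloader import HybridLoader

fc_loader = HybridLoader('data/shoes_fc', '.npy')                    # reads data/shoes_fc/<id>.npy
att_loader = HybridLoader('data/shoes_att', '.npz', in_memory=True)  # reads <id>.npz (key 'feat'), caching the raw bytes in RAM

fc_feat = fc_loader.get('12345')    # numpy array, e.g. (2048,)
att_feat = att_loader.get('12345')  # numpy array, e.g. (14, 14, 2048)
print(fc_feat.shape, att_feat.shape)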
captioning/data/dataloader_recsys.py ADDED
@@ -0,0 +1,432 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.utils.data as data
17
+
18
+ import multiprocessing
19
+ import six
20
+
21
+
22
+ class HybridLoader:
23
+ """
24
+ If db_path is a directory, use normal file loading
25
+ If lmdb, then load from lmdb
26
+ The loading method depends on the extension.
27
+
28
+ in_memory: if in_memory is True, we save all the features in memory
29
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
30
+ Should be useful for lmdb or h5.
31
+ (Copied this idea from vilbert)
32
+ """
33
+ def __init__(self, db_path, ext, in_memory=False):
34
+ self.db_path = db_path
35
+ self.ext = ext
36
+ if self.ext == '.npy':
37
+ self.loader = lambda x: np.load(six.BytesIO(x))
38
+ else:
39
+ def load_npz(x):
40
+ x = np.load(six.BytesIO(x))
41
+ return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
42
+ self.loader = load_npz
43
+ if db_path.endswith('.lmdb'):
44
+ self.db_type = 'lmdb'
45
+ self.lmdb = lmdbdict(db_path, unsafe=True)
46
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
47
+ self.lmdb._value_loads = LOADS_FUNC['identity']
48
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
49
+ self.db_type = 'pth'
50
+ self.feat_file = torch.load(db_path)
51
+ self.loader = lambda x: x
52
+ print('HybridLoader: ext is ignored')
53
+ elif db_path.endswith('h5'):
54
+ self.db_type = 'h5'
55
+ self.loader = lambda x: np.array(x).astype('float32')
56
+ else:
57
+ self.db_type = 'dir'
58
+
59
+ self.in_memory = in_memory
60
+ if self.in_memory:
61
+ self.features = {}
62
+
63
+ def get(self, key):
64
+
65
+ if self.in_memory and key in self.features:
66
+ # We save f_input because we want to save the
67
+ # compressed bytes to save memory
68
+ f_input = self.features[key]
69
+ elif self.db_type == 'lmdb':
70
+ f_input = self.lmdb[key]
71
+ elif self.db_type == 'pth':
72
+ f_input = self.feat_file[key]
73
+ elif self.db_type == 'h5':
74
+ f_input = h5py.File(self.db_path, 'r')[key]
75
+ else:
76
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
77
+
78
+ if self.in_memory and key not in self.features:
79
+ self.features[key] = f_input
80
+
81
+ # load image
82
+ feat = self.loader(f_input)
83
+
84
+ return feat
85
+
86
+ class Dataset(data.Dataset):
87
+
88
+ def get_vocab_size(self):
89
+ return self.vocab_size
90
+
91
+ def get_vocab(self):
92
+ return self.ix_to_word
93
+
94
+ def get_seq_length(self):
95
+ return self.seq_length
96
+
97
+ def __init__(self, opt):
98
+ self.opt = opt
99
+ self.seq_per_img = opt.seq_per_img
100
+
101
+ # feature related options
102
+ self.use_fc = getattr(opt, 'use_fc', True)
103
+ self.use_att = getattr(opt, 'use_att', True)
104
+ self.use_box = getattr(opt, 'use_box', 0)
105
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
106
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
107
+
108
+ # load the json file which contains additional information about the dataset
109
+ print('DataLoader loading json file: ', opt.input_json)
110
+ self.info = json.load(open(self.opt.input_json))
111
+ if 'ix_to_word' in self.info:
112
+ self.ix_to_word = self.info['ix_to_word']
113
+ self.vocab_size = len(self.ix_to_word)
114
+ print('vocab size is ', self.vocab_size)
115
+
116
+ # open the hdf5 file
117
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
118
+ """
119
+ Setting input_label_h5 to none is used when only doing generation.
120
+ For example, when you need to test on coco test set.
121
+ """
122
+ if self.opt.input_label_h5 != 'none':
123
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
124
+ # load in the sequence data
125
+ seq_size = self.h5_label_file['labels'].shape
126
+ self.label = self.h5_label_file['labels'][:]
127
+ self.seq_length = seq_size[1]
128
+ print('max sequence length in data is', self.seq_length)
129
+ # load the pointers in full to RAM (should be small enough)
130
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
131
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
132
+ else:
133
+ self.seq_length = 1
134
+
135
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
136
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
137
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
138
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
139
+
140
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
141
+ print('read %d image features' %(self.num_images))
142
+
143
+ # separate out indexes for each of the provided splits
144
+ self.split_ix = {'train': [], 'val': [], 'test': []}
145
+ for ix in range(len(self.info['images'])):
146
+ img = self.info['images'][ix]
147
+ if not 'split' in img:
148
+ self.split_ix['train'].append(ix)
149
+ self.split_ix['val'].append(ix)
150
+ self.split_ix['test'].append(ix)
151
+ elif img['split'] == 'train':
152
+ self.split_ix['train'].append(ix)
153
+ elif img['split'] == 'val':
154
+ self.split_ix['val'].append(ix)
155
+ elif img['split'] == 'test':
156
+ self.split_ix['test'].append(ix)
157
+ elif opt.train_only == 0: # restval
158
+ self.split_ix['train'].append(ix)
159
+
160
+ print('assigned %d images to split train' %len(self.split_ix['train']))
161
+ print('assigned %d images to split val' %len(self.split_ix['val']))
162
+ print('assigned %d images to split test' %len(self.split_ix['test']))
163
+
164
+ def get_captions(self, ix, seq_per_img):
165
+ # fetch the sequence labels
166
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
167
+ ix2 = self.label_end_ix[ix] - 1
168
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
169
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
170
+
171
+ random.seed(42)
172
+ torch.manual_seed(42)
173
+ if torch.cuda.is_available():
174
+ torch.cuda.manual_seed(42)
175
+
176
+ if ncap < seq_per_img:
177
+ # we need to subsample (with replacement)
178
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
179
+ for q in range(seq_per_img):
180
+ ixl = random.randint(ix1,ix2)
181
+ seq[q, :] = self.label[ixl, :self.seq_length]
182
+ else:
183
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
184
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
185
+
186
+ return seq
187
+
188
+ def collate_func(self, batch, split):
189
+ seq_per_img = self.seq_per_img
190
+
191
+ fc_batch = []
192
+ att_batch = []
193
+ label_batch = []
194
+
195
+ wrapped = False
196
+
197
+ infos = []
198
+ gts = []
199
+
200
+ for sample in batch:
201
+ # fetch image
202
+ tmp_fc, tmp_att, tmp_seq, \
203
+ ix, it_pos_now, tmp_wrapped = sample
204
+ if tmp_wrapped:
205
+ wrapped = True
206
+
207
+ fc_batch.append(tmp_fc)
208
+ att_batch.append(tmp_att)
209
+
210
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
211
+ if hasattr(self, 'h5_label_file'):
212
+ # if there is ground truth
213
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
214
+ label_batch.append(tmp_label)
215
+
216
+ # Used for reward evaluation
217
+ if hasattr(self, 'h5_label_file'):
218
+ # if there is ground truth
219
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
220
+ else:
221
+ gts.append([])
222
+
223
+ # record associated info as well
224
+ info_dict = {}
225
+ info_dict['ix'] = ix
226
+ info_dict['id'] = self.info['images'][ix]['id']
227
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
228
+ infos.append(info_dict)
229
+
230
+ # #sort by att_feat length
231
+ # fc_batch, att_batch, label_batch, gts, infos = \
232
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
233
+ fc_batch, att_batch, label_batch, gts, infos = \
234
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
235
+ data = {}
236
+ data['fc_feats'] = np.stack(fc_batch)
237
+ # merge att_feats
238
+ max_att_len = max([_.shape[0] for _ in att_batch])
239
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
240
+ for i in range(len(att_batch)):
241
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
242
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
243
+ for i in range(len(att_batch)):
244
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
245
+ # set att_masks to None if attention features have same length
246
+ if data['att_masks'].sum() == data['att_masks'].size:
247
+ data['att_masks'] = None
248
+
249
+ data['labels'] = np.vstack(label_batch)
250
+ # generate mask
251
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
252
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
253
+ for ix, row in enumerate(mask_batch):
254
+ row[:nonzeros[ix]] = 1
255
+ data['masks'] = mask_batch
256
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
257
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
258
+
259
+ data['gts'] = gts # all ground-truth captions of each image
260
+ data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
261
+ 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
262
+ data['infos'] = infos
263
+
264
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
265
+
266
+ return data
267
+
268
+ def __getitem__(self, index):
269
+ """This function returns a tuple that is further passed to collate_fn
270
+ """
271
+ ix, it_pos_now, wrapped = index #self.split_ix[index]
272
+ if self.use_att:
273
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
274
+ # Reshape to K x C
275
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
276
+ if self.norm_att_feat:
277
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
278
+ if self.use_box:
279
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
280
+ # divided by image width and height
281
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
282
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
283
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
284
+ if self.norm_box_feat:
285
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
286
+ att_feat = np.hstack([att_feat, box_feat])
287
+ # sort the features by the size of boxes
288
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
289
+ else:
290
+ att_feat = np.zeros((0,0), dtype='float32')
291
+ if self.use_fc:
292
+ try:
293
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
294
+ except:
295
+ # Use average of attention when there is no fc provided (For bottomup feature)
296
+ fc_feat = att_feat.mean(0)
297
+ else:
298
+ fc_feat = np.zeros((0), dtype='float32')
299
+ if hasattr(self, 'h5_label_file'):
300
+ seq = self.get_captions(ix, self.seq_per_img)
301
+ else:
302
+ seq = None
303
+ return (fc_feat,
304
+ att_feat, seq,
305
+ ix, it_pos_now, wrapped)
306
+
307
+ def __len__(self):
308
+ return len(self.info['images'])
309
+
310
+ class DataLoader:
311
+ def __init__(self, opt):
312
+ self.opt = opt
313
+ self.batch_size = self.opt.batch_size
314
+ self.dataset = Dataset(opt)
315
+
316
+ # Initialize loaders and iters
317
+ self.loaders, self.iters = {}, {}
318
+ for split in ['train', 'val', 'test']:
319
+ if split == 'train':
320
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
321
+ else:
322
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
323
+ self.loaders[split] = data.DataLoader(dataset=self.dataset,
324
+ batch_size=self.batch_size,
325
+ sampler=sampler,
326
+ pin_memory=True,
327
+ num_workers=4, # 4 is usually enough
328
+ collate_fn=partial(self.dataset.collate_func, split=split),
329
+ drop_last=False)
330
+ self.iters[split] = iter(self.loaders[split])
331
+
332
+ def get_batch(self, split):
333
+ try:
334
+ data = next(self.iters[split])
335
+ except StopIteration:
336
+ self.iters[split] = iter(self.loaders[split])
337
+ data = next(self.iters[split])
338
+ return data
339
+
340
+ def reset_iterator(self, split):
341
+ self.loaders[split].sampler._reset_iter()
342
+ self.iters[split] = iter(self.loaders[split])
343
+
344
+ def get_vocab_size(self):
345
+ return self.dataset.get_vocab_size()
346
+
347
+ @property
348
+ def vocab_size(self):
349
+ return self.get_vocab_size()
350
+
351
+ def get_vocab(self):
352
+ return self.dataset.get_vocab()
353
+
354
+ def get_seq_length(self):
355
+ return self.dataset.get_seq_length()
356
+
357
+ @property
358
+ def seq_length(self):
359
+ return self.get_seq_length()
360
+
361
+ def state_dict(self):
362
+ def get_prefetch_num(split):
363
+ if self.loaders[split].num_workers > 0:
364
+ return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
365
+ else:
366
+ return 0
367
+ return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
368
+ for split, loader in self.loaders.items()}
369
+
370
+ def load_state_dict(self, state_dict=None):
371
+ if state_dict is None:
372
+ return
373
+ for split in self.loaders.keys():
374
+ self.loaders[split].sampler.load_state_dict(state_dict[split])
375
+
376
+
377
+ class MySampler(data.sampler.Sampler):
378
+ def __init__(self, index_list, shuffle, wrap):
379
+ self.index_list = index_list
380
+ self.shuffle = shuffle
381
+ self.wrap = wrap
382
+ # if wrap is True, StopIteration is never raised
383
+ # wrap=True is used during training, wrap=False during testing.
384
+ self._reset_iter()
385
+
386
+ def __iter__(self):
387
+ return self
388
+
389
+ def __next__(self):
390
+ wrapped = False
391
+ if self.iter_counter == len(self._index_list):
392
+ self._reset_iter()
393
+ if self.wrap:
394
+ wrapped = True
395
+ else:
396
+ raise StopIteration()
397
+ if len(self._index_list) == 0: # overflow when 0 samples
398
+ return None
399
+ elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
400
+ self.iter_counter += 1
401
+ return elem
402
+
403
+ def next(self):
404
+ return self.__next__()
405
+
406
+ def _reset_iter(self):
407
+ np.random.seed(0)
408
+ if self.shuffle:
409
+ rand_perm = npr.permutation(len(self.index_list))
410
+ self._index_list = [self.index_list[_] for _ in rand_perm]
411
+ else:
412
+ self._index_list = self.index_list
413
+
414
+ self.iter_counter = 0
415
+
416
+ def __len__(self):
417
+ return len(self.index_list)
418
+
419
+ def load_state_dict(self, state_dict=None):
420
+ if state_dict is None:
421
+ return
422
+ self._index_list = state_dict['index_list']
423
+ self.iter_counter = state_dict['iter_counter']
424
+
425
+ def state_dict(self, prefetched_num=None):
426
+ prefetched_num = prefetched_num or 0
427
+ return {
428
+ 'index_list': self._index_list,
429
+ 'iter_counter': self.iter_counter - prefetched_num
430
+ }
431
+
432
+
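A hedged sketch of building the DataLoader defined above and checkpointing its iteration state via MySampler.state_dict()/load_state_dict(). The Namespace fields and file paths are hypothetical; Dataset expects real preprocessed json/h5/feature files at those locations.

# Hedged sketch; opt fields and paths are assumptions, not values from this repository.
from argparse import Namespace
import torch
from captioning.data.dataloader_recsys import DataLoader

opt = Namespace(
    batch_size=16, seq_per_img=5, train_only=0,
    input_json='data/shoestalk.json',
    input_fc_dir='data/shoes_fc', input_att_dir='data/shoes_att',
    input_box_dir='data/shoes_box', input_label_h5='data/shoes_label.h5',
)

loader = DataLoader(opt)
batch = loader.get_batch('train')   # dict with fc_feats, att_feats, labels, masks, gts, infos, ...

# Save and restore iteration progress so training can resume mid-epoch.
torch.save(loader.state_dict(), 'loader_state.pth')
loader.load_state_dict(torch.load('loader_state.pth'))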
captioning/data/dataloaderraw.py ADDED
@@ -0,0 +1,151 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ import os
8
+ import numpy as np
9
+ import random
10
+ import torch
11
+ import skimage
12
+ import skimage.io
13
+ import scipy.misc
14
+
15
+ from torchvision import transforms as trn
16
+ preprocess = trn.Compose([
17
+ #trn.ToTensor(),
18
+ trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ # from ..utils.resnet_utils import myResnet
22
+ # from ..utils import resnet
23
+
24
+ from captioning.utils.resnet_utils import myResnet
25
+ from captioning.utils import resnet
26
+
27
+
28
+ class DataLoaderRaw():
29
+
30
+ def __init__(self, opt):
31
+ self.opt = opt
32
+ self.coco_json = opt.get('coco_json', '')
33
+ self.folder_path = opt.get('folder_path', '')
34
+
35
+ self.batch_size = opt.get('batch_size', 1)
36
+ self.seq_per_img = 1
37
+
38
+ # Load resnet
39
+ self.cnn_model = opt.get('cnn_model', 'resnet101')
40
+ self.my_resnet = getattr(resnet, self.cnn_model)()
41
+ self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth'))
42
+ self.my_resnet = myResnet(self.my_resnet)
43
+ self.my_resnet.cuda()
44
+ self.my_resnet.eval()
45
+
46
+
47
+
48
+ # load the json file which contains additional information about the dataset
49
+ print('DataLoaderRaw loading images from folder: ', self.folder_path)
50
+
51
+ self.files = []
52
+ self.ids = []
53
+
54
+ print(len(self.coco_json))
55
+ if len(self.coco_json) > 0:
56
+ print('reading from ' + opt.coco_json)
57
+ # read in filenames from the coco-style json file
58
+ self.coco_annotation = json.load(open(self.coco_json))
59
+ for k,v in enumerate(self.coco_annotation['images']):
60
+ fullpath = os.path.join(self.folder_path, v['file_name'])
61
+ self.files.append(fullpath)
62
+ self.ids.append(v['id'])
63
+ else:
64
+ # read in all the filenames from the folder
65
+ print('listing all images in directory ' + self.folder_path)
66
+ def isImage(f):
67
+ supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM']
68
+ for ext in supportedExt:
69
+ start_idx = f.rfind(ext)
70
+ if start_idx >= 0 and start_idx + len(ext) == len(f):
71
+ return True
72
+ return False
73
+
74
+ n = 1
75
+ for root, dirs, files in os.walk(self.folder_path, topdown=False):
76
+ for file in files:
77
+ fullpath = os.path.join(self.folder_path, file)
78
+ if isImage(fullpath):
79
+ self.files.append(fullpath)
80
+ self.ids.append(str(n)) # just order them sequentially
81
+ n = n + 1
82
+
83
+ self.N = len(self.files)
84
+ print('DataLoaderRaw found ', self.N, ' images')
85
+
86
+ self.iterator = 0
87
+
88
+ # Nasty
89
+ self.dataset = self # to fix the bug in eval
90
+
91
+ def get_batch(self, split, batch_size=None):
92
+ batch_size = batch_size or self.batch_size
93
+
94
+ # pick an index of the datapoint to load next
95
+ fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
96
+ att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32')
97
+ max_index = self.N
98
+ wrapped = False
99
+ infos = []
100
+
101
+ for i in range(batch_size):
102
+ ri = self.iterator
103
+ ri_next = ri + 1
104
+ if ri_next >= max_index:
105
+ ri_next = 0
106
+ wrapped = True
107
+ # wrap back around
108
+ self.iterator = ri_next
109
+
110
+ img = skimage.io.imread(self.files[ri])
111
+
112
+ if len(img.shape) == 2:
113
+ img = img[:,:,np.newaxis]
114
+ img = np.concatenate((img, img, img), axis=2)
115
+
116
+ img = img[:,:,:3].astype('float32')/255.0
117
+ img = torch.from_numpy(img.transpose([2,0,1])).cuda()
118
+ img = preprocess(img)
119
+ with torch.no_grad():
120
+ tmp_fc, tmp_att = self.my_resnet(img)
121
+
122
+ fc_batch[i] = tmp_fc.data.cpu().float().numpy()
123
+ att_batch[i] = tmp_att.data.cpu().float().numpy()
124
+
125
+ info_struct = {}
126
+ info_struct['id'] = self.ids[ri]
127
+ info_struct['file_path'] = self.files[ri]
128
+ infos.append(info_struct)
129
+
130
+ data = {}
131
+ data['fc_feats'] = fc_batch
132
+ data['att_feats'] = att_batch.reshape(batch_size, -1, 2048)
133
+ data['labels'] = np.zeros([batch_size, 0])
134
+ data['masks'] = None
135
+ data['att_masks'] = None
136
+ data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
137
+ data['infos'] = infos
138
+
139
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
140
+
141
+ return data
142
+
143
+ def reset_iterator(self, split):
144
+ self.iterator = 0
145
+
146
+ def get_vocab_size(self):
147
+ return len(self.ix_to_word)
148
+
149
+ def get_vocab(self):
150
+ return self.ix_to_word
151
+
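A hedged sketch of extracting features from a folder of raw images with DataLoaderRaw. Note that the class loads ./data/imagenet_weights/<cnn_model>.pth and calls .cuda() unconditionally, so a GPU and that weight file are required; the folder path below is hypothetical.

# Hedged sketch; 'images/test_shoes' is a placeholder directory of .jpg/.png files.
from captioning.data.dataloaderraw import DataLoaderRaw

loader = DataLoaderRaw({
    'folder_path': 'images/test_shoes',
    'coco_json': '',        # empty string -> just walk the folder
    'batch_size': 8,
    'cnn_model': 'resnet101',
})

batch = loader.get_batch('test')
print(batch['fc_feats'].shape)   # torch.Size([8, 2048])
print(batch['att_feats'].shape)  # torch.Size([8, 196, 2048])
print(batch['infos'][0])         # {'id': ..., 'file_path': ...}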
captioning/data/pth_loader.py ADDED
@@ -0,0 +1,300 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ # random.seed(42)
21
+ # torch.manual_seed(42)
22
+ # if torch.cuda.is_available():
23
+ # torch.cuda.manual_seed(42)
24
+
25
+ class HybridLoader:
26
+ """
27
+ If db_path is a directory, use normal file loading
28
+ If lmdb, then load from lmdb
29
+ The loading method depends on the extension.
30
+
31
+ in_memory: if in_memory is True, we save all the features in memory
32
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
33
+ Should be useful for lmdb or h5.
34
+ (Copied this idea from vilbert)
35
+ """
36
+ def __init__(self, db_path, ext, in_memory=False):
37
+ self.db_path = db_path
38
+ self.ext = ext
39
+ if self.ext == '.npy':
40
+ self.loader = lambda x: np.load(six.BytesIO(x))
41
+ else:
42
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
43
+ if db_path.endswith('.lmdb'):
44
+ self.db_type = 'lmdb'
45
+ self.lmdb = lmdbdict(db_path, unsafe=True)
46
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
47
+ self.lmdb._value_loads = LOADS_FUNC['identity']
48
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
49
+ self.db_type = 'pth'
50
+ self.feat_file = torch.load(db_path)
51
+ self.loader = lambda x: x
52
+ print('HybridLoader: ext is ignored')
53
+ elif db_path.endswith('h5'):
54
+ self.db_type = 'h5'
55
+ self.loader = lambda x: np.array(x).astype('float32')
56
+ else:
57
+ self.db_type = 'dir'
58
+
59
+ self.in_memory = in_memory
60
+ if self.in_memory:
61
+ self.features = {}
62
+
63
+ def get(self, key):
64
+
65
+ if self.in_memory and key in self.features:
66
+ # We save f_input because we want to save the
67
+ # compressed bytes to save memory
68
+ f_input = self.features[key]
69
+ elif self.db_type == 'lmdb':
70
+ f_input = self.lmdb[key]
71
+ elif self.db_type == 'pth':
72
+ f_input = self.feat_file[key]
73
+ elif self.db_type == 'h5':
74
+ f_input = h5py.File(self.db_path, 'r')[key]
75
+ else:
76
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
77
+
78
+ if self.in_memory and key not in self.features:
79
+ self.features[key] = f_input
80
+
81
+ # load image
82
+ feat = self.loader(f_input)
83
+
84
+ return feat
85
+
86
+ class CaptionDataset(data.Dataset):
87
+
88
+ def get_vocab_size(self):
89
+ return self.vocab_size
90
+
91
+ def get_vocab(self):
92
+ return self.ix_to_word
93
+
94
+ def get_seq_length(self):
95
+ return self.seq_length
96
+
97
+ def __init__(self, opt):
98
+ self.opt = opt
99
+ self.seq_per_img = opt.seq_per_img
100
+
101
+ # feature related options
102
+ self.use_fc = getattr(opt, 'use_fc', True)
103
+ self.use_att = getattr(opt, 'use_att', True)
104
+ self.use_box = getattr(opt, 'use_box', 0)
105
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
106
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
107
+
108
+ # load the json file which contains additional information about the dataset
109
+ print('DataLoader loading json file: ', opt.input_json)
110
+ self.info = json.load(open(self.opt.input_json))
111
+ if 'ix_to_word' in self.info:
112
+ self.ix_to_word = self.info['ix_to_word']
113
+ self.vocab_size = len(self.ix_to_word)
114
+ print('vocab size is ', self.vocab_size)
115
+
116
+ # open the hdf5 file
117
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
118
+ """
119
+ Setting input_label_h5 to none is used when only doing generation.
120
+ For example, when you need to test on coco test set.
121
+ """
122
+ if self.opt.input_label_h5 != 'none':
123
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
124
+ # load in the sequence data
125
+ seq_size = self.h5_label_file['labels'].shape
126
+ self.label = self.h5_label_file['labels'][:]
127
+ self.seq_length = seq_size[1]
128
+ print('max sequence length in data is', self.seq_length)
129
+ # load the pointers in full to RAM (should be small enough)
130
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
131
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
132
+ else:
133
+ self.seq_length = 1
134
+
135
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
136
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
137
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
138
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
139
+
140
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
141
+ print('read %d image features' %(self.num_images))
142
+
143
+ # separate out indexes for each of the provided splits
144
+ self.split_ix = {'train': [], 'val': [], 'test': []}
145
+ for ix in range(len(self.info['images'])):
146
+ img = self.info['images'][ix]
147
+ if not 'split' in img:
148
+ self.split_ix['train'].append(ix)
149
+ self.split_ix['val'].append(ix)
150
+ self.split_ix['test'].append(ix)
151
+ elif img['split'] == 'train':
152
+ self.split_ix['train'].append(ix)
153
+ elif img['split'] == 'val':
154
+ self.split_ix['val'].append(ix)
155
+ elif img['split'] == 'test':
156
+ self.split_ix['test'].append(ix)
157
+ elif opt.train_only == 0: # restval
158
+ self.split_ix['train'].append(ix)
159
+
160
+ print('assigned %d images to split train' %len(self.split_ix['train']))
161
+ print('assigned %d images to split val' %len(self.split_ix['val']))
162
+ print('assigned %d images to split test' %len(self.split_ix['test']))
163
+
164
+ def get_captions(self, ix, seq_per_img):
165
+ # fetch the sequence labels
166
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
167
+ ix2 = self.label_end_ix[ix] - 1
168
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
169
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
170
+
171
+ random.seed(42)
172
+
173
+ if ncap < seq_per_img:
174
+ # we need to subsample (with replacement)
175
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
176
+ for q in range(seq_per_img):
177
+ ixl = random.randint(ix1,ix2)
178
+ seq[q, :] = self.label[ixl, :self.seq_length]
179
+ else:
180
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
181
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
182
+
183
+ return seq
184
+
185
+ def collate_func(self, batch):
186
+ seq_per_img = self.seq_per_img
187
+
188
+ fc_batch = []
189
+ att_batch = []
190
+ label_batch = []
191
+
192
+ wrapped = False
193
+
194
+ infos = []
195
+ gts = []
196
+
197
+ for sample in batch:
198
+ # fetch image
199
+ tmp_fc, tmp_att, tmp_seq, \
200
+ ix = sample
201
+
202
+ fc_batch.append(tmp_fc)
203
+ att_batch.append(tmp_att)
204
+
205
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
206
+ if hasattr(self, 'h5_label_file'):
207
+ # if there is ground truth
208
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
209
+ label_batch.append(tmp_label)
210
+
211
+ # Used for reward evaluation
212
+ if hasattr(self, 'h5_label_file'):
213
+ # if there is ground truth
214
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
215
+ else:
216
+ gts.append([])
217
+
218
+ # record associated info as well
219
+ info_dict = {}
220
+ info_dict['ix'] = ix
221
+ info_dict['id'] = self.info['images'][ix]['id']
222
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
223
+ infos.append(info_dict)
224
+
225
+ # #sort by att_feat length
226
+ # fc_batch, att_batch, label_batch, gts, infos = \
227
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
228
+ fc_batch, att_batch, label_batch, gts, infos = \
229
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
230
+ data = {}
231
+ data['fc_feats'] = np.stack(fc_batch)
232
+ # merge att_feats
233
+ max_att_len = max([_.shape[0] for _ in att_batch])
234
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
235
+ for i in range(len(att_batch)):
236
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
237
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
238
+ for i in range(len(att_batch)):
239
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
240
+ # set att_masks to None if attention features have same length
241
+ if data['att_masks'].sum() == data['att_masks'].size:
242
+ data['att_masks'] = None
243
+
244
+ data['labels'] = np.vstack(label_batch)
245
+ # generate mask
246
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
247
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
248
+ for ix, row in enumerate(mask_batch):
249
+ row[:nonzeros[ix]] = 1
250
+ data['masks'] = mask_batch
251
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
252
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
253
+
254
+ data['gts'] = gts # all ground truth captions of each image
255
+ data['infos'] = infos
256
+
257
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
258
+
259
+ return data
260
+
261
+ def __getitem__(self, ix):
262
+ """This function returns a tuple that is further passed to collate_fn
263
+ """
264
+ if self.use_att:
265
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
266
+ # Reshape to K x C
267
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
268
+ if self.norm_att_feat:
269
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
270
+ if self.use_box:
271
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
272
+ # divided by image width and height
273
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
274
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
275
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
276
+ if self.norm_box_feat:
277
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
278
+ att_feat = np.hstack([att_feat, box_feat])
279
+ # sort the features by the size of boxes
280
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
281
+ else:
282
+ att_feat = np.zeros((0,0), dtype='float32')
283
+ if self.use_fc:
284
+ try:
285
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
286
+ except:
287
+ # Use average of attention when there is no fc provided (For bottomup feature)
288
+ fc_feat = att_feat.mean(0)
289
+ else:
290
+ fc_feat = np.zeros((0), dtype='float32')
291
+ if hasattr(self, 'h5_label_file'):
292
+ seq = self.get_captions(ix, self.seq_per_img)
293
+ else:
294
+ seq = None
295
+ return (fc_feat,
296
+ att_feat, seq,
297
+ ix)
298
+
299
+ def __len__(self):
300
+ return len(self.info['images'])
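
An aside on the collate logic above: each caption is padded into a row of seq_length + 2 slots (slot 0 is reserved for the BOS token, the last slot for EOS), and the mask then covers the non-zero tokens plus those two extra slots. A minimal standalone sketch of that padding and masking step, using made-up token ids and a made-up seq_length:

import numpy as np
import torch

seq_length = 6                               # illustrative maximum caption length
captions = np.array([[4, 9, 2, 0, 0, 0],     # toy token ids, 0 = padding
                     [7, 3, 5, 8, 1, 0]])

# pad into seq_length + 2 slots: index 0 stays 0 for <bos>, the tail stays 0 for <eos>
labels = np.zeros([captions.shape[0], seq_length + 2], dtype='int')
labels[:, 1:seq_length + 1] = captions

# mask covers <bos> + tokens + <eos>, i.e. (#non-zero tokens) + 2 ones per row
nonzeros = (labels != 0).sum(axis=1) + 2
masks = np.zeros_like(labels, dtype='float32')
for i, n in enumerate(nonzeros):
    masks[i, :n] = 1

labels_t, masks_t = torch.from_numpy(labels), torch.from_numpy(masks)
print(masks_t.sum(1))                        # tensor([5., 7.]) -> 3 and 5 tokens plus bos/eos
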
captioning/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
captioning/models/AoAModel.py ADDED
@@ -0,0 +1,234 @@
1
+ # Implementation for paper 'Attention on Attention for Image Captioning'
2
+ # https://arxiv.org/abs/1908.06954
3
+
4
+ # RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .AttModel import pack_wrapper, AttModel, Attention
15
+ from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
16
+
17
+ class MultiHeadedDotAttention(nn.Module):
18
+ def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
19
+ super(MultiHeadedDotAttention, self).__init__()
20
+ assert d_model * scale % h == 0
21
+ # We assume d_v always equals d_k
22
+ self.d_k = d_model * scale // h
23
+ self.h = h
24
+
25
+ # Do we need to do linear projections on K and V?
26
+ self.project_k_v = project_k_v
27
+
28
+ # normalize the query?
29
+ if norm_q:
30
+ self.norm = LayerNorm(d_model)
31
+ else:
32
+ self.norm = lambda x:x
33
+ self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
34
+
35
+ # output linear layer after the multi-head attention?
36
+ self.output_layer = nn.Linear(d_model * scale, d_model)
37
+
38
+ # apply aoa after attention?
39
+ self.use_aoa = do_aoa
40
+ if self.use_aoa:
41
+ self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
42
+ # dropout to the input of AoA layer
43
+ if dropout_aoa > 0:
44
+ self.dropout_aoa = nn.Dropout(p=dropout_aoa)
45
+ else:
46
+ self.dropout_aoa = lambda x:x
47
+
48
+ if self.use_aoa or not use_output_layer:
49
+ # AoA doesn't need the output linear layer
50
+ del self.output_layer
51
+ self.output_layer = lambda x:x
52
+
53
+ self.attn = None
54
+ self.dropout = nn.Dropout(p=dropout)
55
+
56
+ def forward(self, query, value, key, mask=None):
57
+ if mask is not None:
58
+ if len(mask.size()) == 2:
59
+ mask = mask.unsqueeze(-2)
60
+ # Same mask applied to all h heads.
61
+ mask = mask.unsqueeze(1)
62
+
63
+ single_query = 0
64
+ if len(query.size()) == 2:
65
+ single_query = 1
66
+ query = query.unsqueeze(1)
67
+
68
+ nbatches = query.size(0)
69
+
70
+ query = self.norm(query)
71
+
72
+ # Do all the linear projections in batch from d_model => h x d_k
73
+ if self.project_k_v == 0:
74
+ query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
75
+ key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
76
+ value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
77
+ else:
78
+ query_, key_, value_ = \
79
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
80
+ for l, x in zip(self.linears, (query, key, value))]
81
+
82
+
83
+ # Apply attention on all the projected vectors in batch.
84
+ x, self.attn = attention(query_, key_, value_, mask=mask,
85
+ dropout=self.dropout)
86
+
87
+ # "Concat" using a view
88
+ x = x.transpose(1, 2).contiguous() \
89
+ .view(nbatches, -1, self.h * self.d_k)
90
+
91
+ if self.use_aoa:
92
+ # Apply AoA
93
+ x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
94
+ # try:
95
+ # x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
96
+ # except:
97
+ # x = self.aoa_layer(self.dropout_aoa(torch.cat([x.view(query.shape), query], -1)))
98
+ # x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query.view(x.shape)], -1)))
99
+
100
+ x = self.output_layer(x)
101
+
102
+ if single_query:
103
+ query = query.squeeze(1)
104
+ x = x.squeeze(1)
105
+ return x
106
+
107
+ class AoA_Refiner_Layer(nn.Module):
108
+ def __init__(self, size, self_attn, feed_forward, dropout):
109
+ super(AoA_Refiner_Layer, self).__init__()
110
+ self.self_attn = self_attn
111
+ self.feed_forward = feed_forward
112
+ self.use_ff = 0
113
+ if self.feed_forward is not None:
114
+ self.use_ff = 1
115
+ self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
116
+ self.size = size
117
+
118
+ def forward(self, x, mask):
119
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
120
+ return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
121
+
122
+ class AoA_Refiner_Core(nn.Module):
123
+ def __init__(self, opt):
124
+ super(AoA_Refiner_Core, self).__init__()
125
+ attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
126
+ layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
127
+ self.layers = clones(layer, 6)
128
+ self.norm = LayerNorm(layer.size)
129
+
130
+ def forward(self, x, mask):
131
+ for layer in self.layers:
132
+ x = layer(x, mask)
133
+ return self.norm(x)
134
+
135
+ class AoA_Decoder_Core(nn.Module):
136
+ def __init__(self, opt):
137
+ super(AoA_Decoder_Core, self).__init__()
138
+ self.drop_prob_lm = opt.drop_prob_lm
139
+ self.d_model = opt.rnn_size
140
+ self.use_multi_head = opt.use_multi_head
141
+ self.multi_head_scale = opt.multi_head_scale
142
+ self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
143
+ self.out_res = getattr(opt, 'out_res', 0)
144
+ self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
145
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
146
+ self.out_drop = nn.Dropout(self.drop_prob_lm)
147
+
148
+ if self.decoder_type == 'AoA':
149
+ # AoA layer
150
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
151
+ elif self.decoder_type == 'LSTM':
152
+ # LSTM layer
153
+ self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
154
+ else:
155
+ # Base linear layer
156
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
157
+
158
+ # if opt.use_multi_head == 1: # TODO, not implemented for now
159
+ # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
160
+ if opt.use_multi_head == 2:
161
+ self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
162
+ else:
163
+ self.attention = Attention(opt)
164
+
165
+ if self.use_ctx_drop:
166
+ self.ctx_drop = nn.Dropout(self.drop_prob_lm)
167
+ else:
168
+ self.ctx_drop = lambda x :x
169
+
170
+ def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
171
+
172
+ # state[0][1] is the context vector at the last step
173
+ h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
174
+
175
+ if self.use_multi_head == 2:
176
+ att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
177
+ else:
178
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
179
+
180
+ ctx_input = torch.cat([att, h_att], 1)
181
+ if self.decoder_type == 'LSTM':
182
+ output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
183
+ state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
184
+ else:
185
+ output = self.att2ctx(ctx_input)
186
+ # save the context vector to state[0][1]
187
+ state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
188
+
189
+ if self.out_res:
190
+ # add residual connection
191
+ output = output + h_att
192
+
193
+ output = self.out_drop(output)
194
+ return output, state
195
+
196
+ class AoAModel(AttModel):
197
+ def __init__(self, opt):
198
+ super(AoAModel, self).__init__(opt)
199
+ self.num_layers = 2
200
+ # mean pooling
201
+ self.use_mean_feats = getattr(opt, 'mean_feats', 1)
202
+ if opt.use_multi_head == 2:
203
+ del self.ctx2att
204
+ self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
205
+
206
+ if self.use_mean_feats:
207
+ del self.fc_embed
208
+ if opt.refine:
209
+ self.refiner = AoA_Refiner_Core(opt)
210
+ else:
211
+ self.refiner = lambda x,y : x
212
+ self.core = AoA_Decoder_Core(opt)
213
+
214
+
215
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
216
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
217
+
218
+ # embed att feats
219
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
220
+ att_feats = self.refiner(att_feats, att_masks)
221
+
222
+ if self.use_mean_feats:
223
+ # mean pooling
224
+ if att_masks is None:
225
+ mean_feats = torch.mean(att_feats, dim=1)
226
+ else:
227
+ mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
228
+ else:
229
+ mean_feats = self.fc_embed(fc_feats)
230
+
231
+ # Project the attention feats first to reduce memory and computation.
232
+ p_att_feats = self.ctx2att(att_feats)
233
+
234
+ return mean_feats, att_feats, p_att_feats, att_masks
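
The "attention on attention" step above concatenates the attended vector with the query and feeds the pair through a linear layer followed by a GLU, so the model gates how much of the attended information reaches the query. A minimal sketch of just that gating step (the width of 512 is illustrative and assumes scale = 1, not necessarily the settings used here):

import torch
import torch.nn as nn

d_model = 512                                 # illustrative size
aoa = nn.Sequential(nn.Linear(2 * d_model, 2 * d_model),   # (1 + scale) * d_model with scale = 1
                    nn.GLU())                 # splits output into an information half and a gate half

query = torch.randn(4, d_model)               # e.g. the decoder query
attended = torch.randn(4, d_model)            # e.g. the multi-head attention output

out = aoa(torch.cat([attended, query], dim=-1))
print(out.shape)                              # torch.Size([4, 512]), same width as the query
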
captioning/models/AttEnsemble.py ADDED
@@ -0,0 +1,90 @@
1
+ # This file is the implementation for ensemble evaluation.
2
+
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.autograd import *
12
+
13
+ from .CaptionModel import CaptionModel
14
+ from .AttModel import pack_wrapper, AttModel
15
+
16
+ class AttEnsemble(AttModel):
17
+ def __init__(self, models, weights=None):
18
+ CaptionModel.__init__(self)
19
+ # super(AttEnsemble, self).__init__()
20
+
21
+ self.models = nn.ModuleList(models)
22
+ self.vocab_size = models[0].vocab_size
23
+ self.seq_length = models[0].seq_length
24
+ self.bad_endings_ix = models[0].bad_endings_ix
25
+ self.ss_prob = 0
26
+ weights = weights or [1.0] * len(self.models)
27
+ self.register_buffer('weights', torch.tensor(weights))
28
+
29
+ def init_hidden(self, batch_size):
30
+ state = [m.init_hidden(batch_size) for m in self.models]
31
+ return self.pack_state(state)
32
+
33
+ def pack_state(self, state):
34
+ self.state_lengths = [len(_) for _ in state]
35
+ return sum([list(_) for _ in state], [])
36
+
37
+ def unpack_state(self, state):
38
+ out = []
39
+ for l in self.state_lengths:
40
+ out.append(state[:l])
41
+ state = state[l:]
42
+ return out
43
+
44
+ def embed(self, it):
45
+ return [m.embed(it) for m in self.models]
46
+
47
+ def core(self, *args):
48
+ return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
49
+
50
+ def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
51
+ # 'it' contains a word index
52
+ xt = self.embed(it)
53
+
54
+ state = self.unpack_state(state)
55
+ output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
56
+ logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
57
+
58
+ return logprobs, self.pack_state(state)
59
+
60
+ def _prepare_feature(self, *args):
61
+ return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
62
+
63
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
64
+ beam_size = opt.get('beam_size', 10)
65
+ batch_size = fc_feats.size(0)
66
+
67
+ fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
68
+
69
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
70
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
71
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
72
+ # lets process every image independently for now, for simplicity
73
+
74
+ self.done_beams = [[] for _ in range(batch_size)]
75
+ for k in range(batch_size):
76
+ state = self.init_hidden(beam_size)
77
+ tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
78
+ tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
79
+ tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
80
+ tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
81
+
82
+ it = fc_feats[0].data.new(beam_size).long().zero_()
83
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
84
+
85
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
86
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
87
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
88
+ # return the samples and their log likelihoods
89
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
90
+ # return the samples and their log likelihoods
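
The ensemble above combines its members in probability space: each model's logits are softmaxed, the resulting distributions are averaged with the given weights, and the average is taken back to log space. A standalone sketch of that averaging step with dummy logits:

import torch
import torch.nn.functional as F

# dummy per-model logits: 3 models, batch of 2, vocabulary of 5
logits = [torch.randn(2, 5) for _ in range(3)]
weights = torch.tensor([1.0, 1.0, 2.0])       # e.g. trust the third model twice as much

probs = torch.stack([F.softmax(l, dim=1) for l in logits], dim=2)   # (batch, vocab, n_models)
avg = (probs * weights).sum(-1) / weights.sum()                     # weighted mean distribution
logprobs = avg.log()                                                # back to log probabilities

print(torch.allclose(avg.sum(1), torch.ones(2)))                    # True: each row is still a distribution
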
captioning/models/AttModel.py ADDED
@@ -0,0 +1,977 @@
1
+ # This file contains the Att2in2, AdaAtt, AdaAttMO, and UpDown models
2
+
3
+ # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4
+ # https://arxiv.org/abs/1612.01887
5
+ # AdaAttMO is a modified version with maxout lstm
6
+
7
+ # Att2in is from Self-critical Sequence Training for Image Captioning
8
+ # https://arxiv.org/abs/1612.00563
9
+ # In this file we only have Att2in2, which is a slightly different version of att2in,
10
+ # in which the image feature embedding and word embedding are the same as in AdaAtt.
11
+
12
+ # UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
13
+ # https://arxiv.org/abs/1707.07998
14
+ # However, it may not be identical to the author's architecture.
15
+
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from . import utils
25
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
26
+
27
+ from .CaptionModel import CaptionModel
28
+
29
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
30
+ bad_endings += ['UNK', 'has', 'and', 'more']
31
+
32
+ def sort_pack_padded_sequence(input, lengths):
33
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
34
+ tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
35
+ inv_ix = indices.clone()
36
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
37
+ return tmp, inv_ix
38
+
39
+ def pad_unsort_packed_sequence(input, inv_ix):
40
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
41
+ tmp = tmp[inv_ix]
42
+ return tmp
43
+
44
+ def pack_wrapper(module, att_feats, att_masks):
45
+ if att_masks is not None:
46
+ packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
47
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
48
+ else:
49
+ return module(att_feats)
50
+
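
(Illustrative aside, not part of AttModel.py.) pack_wrapper above works around pack_padded_sequence requiring length-sorted input: sort by mask length, pack, apply the module to the packed data only, then unpack and restore the original order. A small round trip on toy features, with sizes chosen arbitrarily:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

feats = torch.randn(3, 4, 8)                  # batch of 3, up to 4 regions, 8-dim features
lengths = torch.tensor([2, 4, 3])             # valid regions per example (not sorted)

sorted_len, idx = torch.sort(lengths, descending=True)
packed = pack_padded_sequence(feats[idx], sorted_len, batch_first=True)

linear = nn.Linear(8, 16)
out = PackedSequence(linear(packed.data), packed.batch_sizes)   # apply the module to packed data only

padded, _ = pad_packed_sequence(out, batch_first=True)
inv = idx.clone()
inv[idx] = torch.arange(len(idx))             # inverse permutation
restored = padded[inv]                        # back to the original batch order
print(restored.shape)                         # torch.Size([3, 4, 16])
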
51
+ class AttModel(CaptionModel):
52
+ def __init__(self, opt):
53
+ super(AttModel, self).__init__()
54
+ self.vocab_size = opt.vocab_size
55
+ self.input_encoding_size = opt.input_encoding_size
56
+ #self.rnn_type = opt.rnn_type
57
+ self.rnn_size = opt.rnn_size
58
+ self.num_layers = opt.num_layers
59
+ self.drop_prob_lm = opt.drop_prob_lm
60
+ self.seq_length = getattr(opt, 'max_length', 16) or opt.seq_length # maximum sample length
61
+ self.fc_feat_size = opt.fc_feat_size
62
+ self.att_feat_size = opt.att_feat_size
63
+ self.att_hid_size = opt.att_hid_size
64
+
65
+ self.bos_idx = getattr(opt, 'bos_idx', 0)
66
+ self.eos_idx = getattr(opt, 'eos_idx', 0)
67
+ self.pad_idx = getattr(opt, 'pad_idx', 0)
68
+
69
+ self.use_bn = getattr(opt, 'use_bn', 0)
70
+
71
+ self.ss_prob = 0.0 # Schedule sampling probability
72
+
73
+ self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
74
+ nn.ReLU(),
75
+ nn.Dropout(self.drop_prob_lm))
76
+ self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
77
+ nn.ReLU(),
78
+ nn.Dropout(self.drop_prob_lm))
79
+ self.att_embed = nn.Sequential(*(
80
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
81
+ (nn.Linear(self.att_feat_size, self.rnn_size),
82
+ nn.ReLU(),
83
+ nn.Dropout(self.drop_prob_lm))+
84
+ ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
85
+
86
+ self.logit_layers = getattr(opt, 'logit_layers', 1)
87
+ if self.logit_layers == 1:
88
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
89
+ else:
90
+ self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
91
+ self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
92
+ self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
93
+
94
+ # For removing bad endings
95
+ self.vocab = opt.vocab
96
+ self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
97
+
98
+ def init_hidden(self, bsz):
99
+ weight = self.logit.weight \
100
+ if hasattr(self.logit, "weight") \
101
+ else self.logit[0].weight
102
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
103
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
104
+
105
+ def clip_att(self, att_feats, att_masks):
106
+ # Clip the length of att_masks and att_feats to the maximum length
107
+ if att_masks is not None:
108
+ max_len = att_masks.data.long().sum(1).max()
109
+ att_feats = att_feats[:, :max_len].contiguous()
110
+ att_masks = att_masks[:, :max_len].contiguous()
111
+ return att_feats, att_masks
112
+
113
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
114
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
115
+
116
+ # embed fc and att feats
117
+ fc_feats = self.fc_embed(fc_feats)
118
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
119
+
120
+ # Project the attention feats first to reduce memory and computation consumption.
121
+ p_att_feats = self.ctx2att(att_feats)
122
+
123
+ return fc_feats, att_feats, p_att_feats, att_masks
124
+
125
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
126
+
127
+ batch_size = fc_feats.size(0)
128
+ if seq.ndim == 3: # B * seq_per_img * seq_len
129
+ seq = seq.reshape(-1, seq.shape[2])
130
+ seq_per_img = seq.shape[0] // batch_size
131
+ state = self.init_hidden(batch_size*seq_per_img)
132
+
133
+ outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
134
+
135
+ # Prepare the features
136
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
137
+ # pp_att_feats is used for attention, we cache it in advance to reduce computation cost
138
+
139
+ if seq_per_img > 1:
140
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
141
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
142
+ )
143
+
144
+ for i in range(seq.size(1)):
145
+ if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
146
+ sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
147
+ sample_mask = sample_prob < self.ss_prob
148
+ if sample_mask.sum() == 0:
149
+ it = seq[:, i].clone()
150
+ else:
151
+ sample_ind = sample_mask.nonzero().view(-1)
152
+ it = seq[:, i].data.clone()
153
+ prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
154
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
155
+ else:
156
+ it = seq[:, i].clone()
157
+ # break if all the sequences end
158
+ if i >= 1 and seq[:, i].sum() == 0:
159
+ break
160
+
161
+ output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
162
+ outputs[:, i] = output
163
+
164
+ return outputs
165
+
166
+ def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
167
+ # 'it' contains a word index
168
+ xt = self.embed(it)
169
+
170
+ output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
171
+ if output_logsoftmax:
172
+ logprobs = F.log_softmax(self.logit(output), dim=1)
173
+ else:
174
+ logprobs = self.logit(output)
175
+
176
+ return logprobs, state
177
+
178
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
179
+ beam_size = opt.get('beam_size', 10)
180
+ group_size = opt.get('group_size', 1)
181
+ sample_n = opt.get('sample_n', 10)
182
+ # when sample_n == beam_size then each beam is a sample.
183
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when using beam search, sample_n must be 1 or beam_size // group_size'
184
+ batch_size = fc_feats.size(0)
185
+
186
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
187
+
188
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
189
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
190
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
191
+ # lets process every image independently for now, for simplicity
192
+
193
+ self.done_beams = [[] for _ in range(batch_size)]
194
+ for k in range(batch_size):
195
+ state = self.init_hidden(beam_size)
196
+ tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
197
+ [p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
198
+ )
199
+
200
+ for t in range(1):
201
+ if t == 0: # input <bos>
202
+ it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
203
+
204
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
205
+
206
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
207
+ if sample_n == beam_size:
208
+ for _n in range(sample_n):
209
+ seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
210
+ seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
211
+ else:
212
+ seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
213
+ seqLogprobs[k, :] = self.done_beams[k][0]['logps']
214
+ # return the samples and their log likelihoods
215
+ return seq, seqLogprobs
216
+
217
+
218
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
219
+ beam_size = opt.get('beam_size', 10)
220
+ group_size = opt.get('group_size', 1)
221
+ sample_n = opt.get('sample_n', 10)
222
+ # when sample_n == beam_size then each beam is a sample.
223
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when using beam search, sample_n must be 1 or beam_size // group_size'
224
+ batch_size = fc_feats.size(0)
225
+
226
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
227
+
228
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
229
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
230
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
231
+ # lets process every image independently for now, for simplicity
232
+
233
+ self.done_beams = [[] for _ in range(batch_size)]
234
+
235
+ state = self.init_hidden(batch_size)
236
+
237
+ # first step, feed bos
238
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
239
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
240
+
241
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
242
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
243
+ )
244
+ self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
245
+
246
+ for k in range(batch_size):
247
+ if sample_n == beam_size:
248
+ for _n in range(sample_n):
249
+ seq_len = self.done_beams[k][_n]['seq'].shape[0]
250
+ seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
251
+ seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
252
+ else:
253
+ seq_len = self.done_beams[k][0]['seq'].shape[0]
254
+ seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
255
+ seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
256
+ # return the samples and their log likelihoods
257
+ return seq, seqLogprobs
258
+
259
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
260
+
261
+ sample_method = opt.get('sample_method', 'greedy')
262
+ beam_size = opt.get('beam_size', 1)
263
+ temperature = opt.get('temperature', 1.0)
264
+ sample_n = int(opt.get('sample_n', 1))
265
+ group_size = opt.get('group_size', 1)
266
+ output_logsoftmax = opt.get('output_logsoftmax', 1)
267
+ decoding_constraint = opt.get('decoding_constraint', 0)
268
+ block_trigrams = opt.get('block_trigrams', 0)
269
+ remove_bad_endings = opt.get('remove_bad_endings', 1)
270
+ suppress_UNK = opt.get('suppress_UNK', 1)
271
+
272
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
273
+ return self._sample_beam(fc_feats, att_feats, att_masks, opt)
274
+ if group_size > 1:
275
+ return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
276
+
277
+ batch_size = fc_feats.size(0)
278
+ state = self.init_hidden(batch_size*sample_n)
279
+
280
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
281
+
282
+ if sample_n > 1:
283
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
284
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
285
+ )
286
+
287
+ trigrams = [] # will be a list of batch_size dictionaries
288
+
289
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
290
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
291
+ for t in range(self.seq_length + 1):
292
+ if t == 0: # input <bos>
293
+ it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
294
+
295
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
296
+
297
+ if decoding_constraint and t > 0:
298
+ tmp = logprobs.new_zeros(logprobs.size())
299
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
300
+ logprobs = logprobs + tmp
301
+
302
+ if remove_bad_endings and t > 0:
303
+ logprobs[torch.from_numpy(np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
304
+ # suppress UNK tokens in the decoding
305
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
306
+ logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
307
+
308
+ # if remove_bad_endings and t > 0:
309
+ # tmp = logprobs.new_zeros(logprobs.size())
310
+ # prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
311
+ # # Make it impossible to generate bad_endings
312
+ # tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
313
+ # logprobs = logprobs + tmp
314
+
315
+ # Mess with trigrams
316
+ # Copy from https://github.com/lukemelas/image-paragraph-captioning
317
+ if block_trigrams and t >= 3:
318
+ # Store trigram generated at last step
319
+ prev_two_batch = seq[:,t-3:t-1]
320
+ for i in range(batch_size): # = seq.size(0)
321
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
322
+ current = seq[i][t-1]
323
+ if t == 3: # initialize
324
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
325
+ elif t > 3:
326
+ if prev_two in trigrams[i]: # add to list
327
+ trigrams[i][prev_two].append(current)
328
+ else: # create list
329
+ trigrams[i][prev_two] = [current]
330
+ # Block used trigrams at next step
331
+ prev_two_batch = seq[:,t-2:t]
332
+ mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
333
+ for i in range(batch_size):
334
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
335
+ if prev_two in trigrams[i]:
336
+ for j in trigrams[i][prev_two]:
337
+ mask[i,j] += 1
338
+ # Apply mask to log probs
339
+ #logprobs = logprobs - (mask * 1e9)
340
+ alpha = 2.0 # = 4
341
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
342
+
343
+ # sample the next word
344
+ if t == self.seq_length: # skip if we achieve maximum length
345
+ break
346
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
347
+
348
+ # stop when all finished
349
+ if t == 0:
350
+ unfinished = it != self.eos_idx
351
+ else:
352
+ it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
353
+ logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
354
+ unfinished = unfinished & (it != self.eos_idx)
355
+ seq[:,t] = it
356
+ seqLogprobs[:,t] = logprobs
357
+ # quit loop if all sequences have finished
358
+ if unfinished.sum() == 0:
359
+ break
360
+ return seq, seqLogprobs
361
+
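
(Illustrative aside, not part of AttModel.py.) The trigram-blocking logic in _sample keeps, per example, a table of trigrams generated so far and then penalises any token that would complete a previously seen trigram by adding ln(1/2) * alpha to its log probability. A toy single-example sketch of that bookkeeping:

import torch

seq = [5, 9, 2, 9, 2]                          # token ids generated so far (toy example)
trigrams = {}                                  # (w_{t-2}, w_{t-1}) -> list of observed w_t
for a, b, c in zip(seq, seq[1:], seq[2:]):
    trigrams.setdefault((a, b), []).append(c)

logprobs = torch.zeros(1, 12)                  # dummy log-probs over a 12-word vocabulary
prev_two = (seq[-2], seq[-1])                  # (9, 2): emitting 9 again would repeat a trigram
mask = torch.zeros_like(logprobs)
for w in trigrams.get(prev_two, []):
    mask[0, w] += 1

alpha = 2.0
logprobs = logprobs + mask * (-0.693) * alpha  # ln(1/2) * alpha penalty per repeated trigram
print(logprobs[0, 9])                          # tensor(-1.3860): token 9 is now discouraged
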
362
+ def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
363
+
364
+ sample_method = opt.get('sample_method', 'greedy')
365
+ beam_size = opt.get('beam_size', 1)
366
+ temperature = opt.get('temperature', 1.0)
367
+ group_size = opt.get('group_size', 1)
368
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
369
+ decoding_constraint = opt.get('decoding_constraint', 0)
370
+ block_trigrams = opt.get('block_trigrams', 0)
371
+ remove_bad_endings = opt.get('remove_bad_endings', 1)
372
+
373
+ batch_size = fc_feats.size(0)
374
+ state = self.init_hidden(batch_size)
375
+
376
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
377
+
378
+ trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
379
+
380
+ seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
381
+ seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
382
+ state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
383
+
384
+ for tt in range(self.seq_length + group_size):
385
+ for divm in range(group_size):
386
+ t = tt - divm
387
+ seq = seq_table[divm]
388
+ seqLogprobs = seqLogprobs_table[divm]
389
+ trigrams = trigrams_table[divm]
390
+ if t >= 0 and t <= self.seq_length-1:
391
+ if t == 0: # input <bos>
392
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
393
+ else:
394
+ it = seq[:, t-1] # changed
395
+
396
+ logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
397
+ logprobs = F.log_softmax(logprobs / temperature, dim=-1)
398
+
399
+ # Add diversity
400
+ if divm > 0:
401
+ unaug_logprobs = logprobs.clone()
402
+ for prev_choice in range(divm):
403
+ prev_decisions = seq_table[prev_choice][:, t]
404
+ logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
405
+
406
+ if decoding_constraint and t > 0:
407
+ tmp = logprobs.new_zeros(logprobs.size())
408
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
409
+ logprobs = logprobs + tmp
410
+
411
+ if remove_bad_endings and t > 0:
412
+ tmp = logprobs.new_zeros(logprobs.size())
413
+ prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
414
+ # Make it impossible to generate bad endings
415
+ tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
416
+ logprobs = logprobs + tmp
417
+
418
+ # Mess with trigrams
419
+ if block_trigrams and t >= 3:
420
+ # Store trigram generated at last step
421
+ prev_two_batch = seq[:,t-3:t-1]
422
+ for i in range(batch_size): # = seq.size(0)
423
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
424
+ current = seq[i][t-1]
425
+ if t == 3: # initialize
426
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
427
+ elif t > 3:
428
+ if prev_two in trigrams[i]: # add to list
429
+ trigrams[i][prev_two].append(current)
430
+ else: # create list
431
+ trigrams[i][prev_two] = [current]
432
+ # Block used trigrams at next step
433
+ prev_two_batch = seq[:,t-2:t]
434
+ mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
435
+ for i in range(batch_size):
436
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
437
+ if prev_two in trigrams[i]:
438
+ for j in trigrams[i][prev_two]:
439
+ mask[i,j] += 1
440
+ # Apply mask to log probs
441
+ #logprobs = logprobs - (mask * 1e9)
442
+ alpha = 2.0 # = 4
443
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
444
+
445
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
446
+
447
+ # stop when all finished
448
+ if t == 0:
449
+ unfinished = it != self.eos_idx
450
+ else:
451
+ unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
452
+ it[~unfinished] = self.pad_idx
453
+ unfinished = unfinished & (it != self.eos_idx) # changed
454
+ seq[:,t] = it
455
+ seqLogprobs[:,t] = sampleLogprobs.view(-1)
456
+
457
+ return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
458
+
459
+ class AdaAtt_lstm(nn.Module):
460
+ def __init__(self, opt, use_maxout=True):
461
+ super(AdaAtt_lstm, self).__init__()
462
+ self.input_encoding_size = opt.input_encoding_size
463
+ #self.rnn_type = opt.rnn_type
464
+ self.rnn_size = opt.rnn_size
465
+ self.num_layers = opt.num_layers
466
+ self.drop_prob_lm = opt.drop_prob_lm
467
+ self.fc_feat_size = opt.fc_feat_size
468
+ self.att_feat_size = opt.att_feat_size
469
+ self.att_hid_size = opt.att_hid_size
470
+
471
+ self.use_maxout = use_maxout
472
+
473
+ # Build a LSTM
474
+ self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
475
+ self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
476
+
477
+ self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
478
+ self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
479
+
480
+ # Layers for getting the fake region
481
+ if self.num_layers == 1:
482
+ self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
483
+ self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
484
+ else:
485
+ self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
486
+ self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
487
+
488
+
489
+ def forward(self, xt, img_fc, state):
490
+
491
+ hs = []
492
+ cs = []
493
+ for L in range(self.num_layers):
494
+ # c,h from previous timesteps
495
+ prev_h = state[0][L]
496
+ prev_c = state[1][L]
497
+ # the input to this layer
498
+ if L == 0:
499
+ x = xt
500
+ i2h = self.w2h(x) + self.v2h(img_fc)
501
+ else:
502
+ x = hs[-1]
503
+ x = F.dropout(x, self.drop_prob_lm, self.training)
504
+ i2h = self.i2h[L-1](x)
505
+
506
+ all_input_sums = i2h+self.h2h[L](prev_h)
507
+
508
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
509
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
510
+ # decode the gates
511
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
512
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
513
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
514
+ # decode the write inputs
515
+ if not self.use_maxout:
516
+ in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
517
+ else:
518
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
519
+ in_transform = torch.max(\
520
+ in_transform.narrow(1, 0, self.rnn_size),
521
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
522
+ # perform the LSTM update
523
+ next_c = forget_gate * prev_c + in_gate * in_transform
524
+ # gated cells form the output
525
+ tanh_nex_c = torch.tanh(next_c)
526
+ next_h = out_gate * tanh_nex_c
527
+ if L == self.num_layers-1:
528
+ if L == 0:
529
+ i2h = self.r_w2h(x) + self.r_v2h(img_fc)
530
+ else:
531
+ i2h = self.r_i2h(x)
532
+ n5 = i2h+self.r_h2h(prev_h)
533
+ fake_region = torch.sigmoid(n5) * tanh_nex_c
534
+
535
+ cs.append(next_c)
536
+ hs.append(next_h)
537
+
538
+ # set up the decoder
539
+ top_h = hs[-1]
540
+ top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
541
+ fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
542
+
543
+ state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
544
+ torch.cat([_.unsqueeze(0) for _ in cs], 0))
545
+ return top_h, fake_region, state
546
+
547
+ class AdaAtt_attention(nn.Module):
548
+ def __init__(self, opt):
549
+ super(AdaAtt_attention, self).__init__()
550
+ self.input_encoding_size = opt.input_encoding_size
551
+ #self.rnn_type = opt.rnn_type
552
+ self.rnn_size = opt.rnn_size
553
+ self.drop_prob_lm = opt.drop_prob_lm
554
+ self.att_hid_size = opt.att_hid_size
555
+
556
+ # fake region embed
557
+ self.fr_linear = nn.Sequential(
558
+ nn.Linear(self.rnn_size, self.input_encoding_size),
559
+ nn.ReLU(),
560
+ nn.Dropout(self.drop_prob_lm))
561
+ self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
562
+
563
+ # h out embed
564
+ self.ho_linear = nn.Sequential(
565
+ nn.Linear(self.rnn_size, self.input_encoding_size),
566
+ nn.Tanh(),
567
+ nn.Dropout(self.drop_prob_lm))
568
+ self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
569
+
570
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
571
+ self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
572
+
573
+ def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
574
+
575
+ # View into three dimensions
576
+ att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
577
+ conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
578
+ conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
579
+
580
+ # view neighbor from batch_size * neighbor_num x rnn_size to batch_size x rnn_size * neighbor_num
581
+ fake_region = self.fr_linear(fake_region)
582
+ fake_region_embed = self.fr_embed(fake_region)
583
+
584
+ h_out_linear = self.ho_linear(h_out)
585
+ h_out_embed = self.ho_embed(h_out_linear)
586
+
587
+ txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
588
+
589
+ img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
590
+ img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
591
+
592
+ hA = torch.tanh(img_all_embed + txt_replicate)
593
+ hA = F.dropout(hA,self.drop_prob_lm, self.training)
594
+
595
+ hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
596
+ PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
597
+
598
+ if att_masks is not None:
599
+ att_masks = att_masks.view(-1, att_size)
600
+ PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume the mask is one at the first (fake region) position.
601
+ PI = PI / PI.sum(1, keepdim=True)
602
+
603
+ visAtt = torch.bmm(PI.unsqueeze(1), img_all)
604
+ visAttdim = visAtt.squeeze(1)
605
+
606
+ atten_out = visAttdim + h_out_linear
607
+
608
+ h = torch.tanh(self.att2h(atten_out))
609
+ h = F.dropout(h, self.drop_prob_lm, self.training)
610
+ return h
611
+
612
+ class AdaAttCore(nn.Module):
613
+ def __init__(self, opt, use_maxout=False):
614
+ super(AdaAttCore, self).__init__()
615
+ self.lstm = AdaAtt_lstm(opt, use_maxout)
616
+ self.attention = AdaAtt_attention(opt)
617
+
618
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
619
+ h_out, p_out, state = self.lstm(xt, fc_feats, state)
620
+ atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
621
+ return atten_out, state
622
+
623
+ class UpDownCore(nn.Module):
624
+ def __init__(self, opt, use_maxout=False):
625
+ super(UpDownCore, self).__init__()
626
+ self.drop_prob_lm = opt.drop_prob_lm
627
+
628
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
629
+ self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
630
+ self.attention = Attention(opt)
631
+
632
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
633
+ prev_h = state[0][-1]
634
+ att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
635
+
636
+ h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
637
+
638
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
639
+
640
+ lang_lstm_input = torch.cat([att, h_att], 1)
641
+ # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????
642
+
643
+ h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
644
+
645
+ output = F.dropout(h_lang, self.drop_prob_lm, self.training)
646
+ state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
647
+
648
+ return output, state
649
+
650
+
651
+ ############################################################################
652
+ # Notice:
653
+ # StackAtt and DenseAtt are models that I randomly designed.
654
+ # They are not related to any paper.
655
+ ############################################################################
656
+
657
+ from .FCModel import LSTMCore
658
+ class StackAttCore(nn.Module):
659
+ def __init__(self, opt, use_maxout=False):
660
+ super(StackAttCore, self).__init__()
661
+ self.drop_prob_lm = opt.drop_prob_lm
662
+
663
+ # self.att0 = Attention(opt)
664
+ self.att1 = Attention(opt)
665
+ self.att2 = Attention(opt)
666
+
667
+ opt_input_encoding_size = opt.input_encoding_size
668
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
669
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
670
+ opt.input_encoding_size = opt.rnn_size * 2
671
+ self.lstm1 = LSTMCore(opt)
672
+ self.lstm2 = LSTMCore(opt)
673
+ opt.input_encoding_size = opt_input_encoding_size
674
+
675
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
676
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
677
+
678
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
679
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
680
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
681
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
682
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
683
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
684
+ h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
685
+
686
+ return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
687
+
688
+ class DenseAttCore(nn.Module):
689
+ def __init__(self, opt, use_maxout=False):
690
+ super(DenseAttCore, self).__init__()
691
+ self.drop_prob_lm = opt.drop_prob_lm
692
+
693
+ # self.att0 = Attention(opt)
694
+ self.att1 = Attention(opt)
695
+ self.att2 = Attention(opt)
696
+
697
+ opt_input_encoding_size = opt.input_encoding_size
698
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
699
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
700
+ opt.input_encoding_size = opt.rnn_size * 2
701
+ self.lstm1 = LSTMCore(opt)
702
+ self.lstm2 = LSTMCore(opt)
703
+ opt.input_encoding_size = opt_input_encoding_size
704
+
705
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
706
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
707
+
708
+ # fuse h_0 and h_1
709
+ self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
710
+ nn.ReLU(),
711
+ nn.Dropout(opt.drop_prob_lm))
712
+ # fuse h_0, h_1 and h_2
713
+ self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
714
+ nn.ReLU(),
715
+ nn.Dropout(opt.drop_prob_lm))
716
+
717
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
718
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
719
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
720
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
721
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
722
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
723
+ h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
724
+
725
+ return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
726
+
727
+ class Attention(nn.Module):
728
+ def __init__(self, opt):
729
+ super(Attention, self).__init__()
730
+ self.rnn_size = opt.rnn_size
731
+ self.att_hid_size = opt.att_hid_size
732
+
733
+ self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
734
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
735
+
736
+ def forward(self, h, att_feats, p_att_feats, att_masks=None):
737
+ # The p_att_feats here is already projected
738
+ att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
739
+ att = p_att_feats.view(-1, att_size, self.att_hid_size)
740
+
741
+ att_h = self.h2att(h) # batch * att_hid_size
742
+ att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
743
+ dot = att + att_h # batch * att_size * att_hid_size
744
+ dot = torch.tanh(dot) # batch * att_size * att_hid_size
745
+ dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
746
+ dot = self.alpha_net(dot) # (batch * att_size) * 1
747
+ dot = dot.view(-1, att_size) # batch * att_size
748
+
749
+ weight = F.softmax(dot, dim=1) # batch * att_size
750
+ if att_masks is not None:
751
+ weight = weight * att_masks.view(-1, att_size).to(weight)
752
+ weight = weight / weight.sum(1, keepdim=True) # normalize to 1
753
+ att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
754
+ att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
755
+
756
+ return att_res
757
+
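
(Illustrative aside, not part of AttModel.py.) The Attention module above is standard additive attention: the hidden state is projected into the attention space, added to the pre-projected region features, squashed with tanh, scored by alpha_net, softmaxed into weights, and used to take a weighted sum of the region features. The same computation on random tensors, with illustrative sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

rnn_size, att_hid, K = 512, 256, 36          # illustrative sizes (K = number of regions)
h2att = nn.Linear(rnn_size, att_hid)
alpha_net = nn.Linear(att_hid, 1)

h = torch.randn(2, rnn_size)                 # decoder hidden state
att_feats = torch.randn(2, K, rnn_size)      # region features (already embedded)
p_att_feats = torch.randn(2, K, att_hid)     # pre-projected features (the ctx2att output)

dot = torch.tanh(p_att_feats + h2att(h).unsqueeze(1))            # (2, K, att_hid)
weight = F.softmax(alpha_net(dot).squeeze(-1), dim=1)            # (2, K) attention weights
att_res = torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)   # (2, rnn_size) attended vector
print(att_res.shape)
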
758
+ class Att2in2Core(nn.Module):
759
+ def __init__(self, opt):
760
+ super(Att2in2Core, self).__init__()
761
+ self.input_encoding_size = opt.input_encoding_size
762
+ #self.rnn_type = opt.rnn_type
763
+ self.rnn_size = opt.rnn_size
764
+ #self.num_layers = opt.num_layers
765
+ self.drop_prob_lm = opt.drop_prob_lm
766
+ self.fc_feat_size = opt.fc_feat_size
767
+ self.att_feat_size = opt.att_feat_size
768
+ self.att_hid_size = opt.att_hid_size
769
+
770
+ # Build a LSTM
771
+ self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
772
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
773
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
774
+ self.dropout = nn.Dropout(self.drop_prob_lm)
775
+
776
+ self.attention = Attention(opt)
777
+
778
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
779
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
780
+
781
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
782
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
783
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
784
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
785
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
786
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
787
+
788
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
789
+ self.a2c(att_res)
790
+ in_transform = torch.max(\
791
+ in_transform.narrow(1, 0, self.rnn_size),
792
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
793
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
794
+ next_h = out_gate * torch.tanh(next_c)
795
+
796
+ output = self.dropout(next_h)
797
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
798
+ return output, state
799
+
800
+ class Att2inCore(Att2in2Core):
801
+ def __init__(self, opt):
802
+ super(Att2inCore, self).__init__(opt)
803
+ del self.a2c
804
+ self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
805
+
806
+ """
807
+ Note this is my attempt to replicate the att2all model in the self-critical paper.
808
+ However, this is not actually a correct replication. Will fix it.
809
+ """
810
+ class Att2all2Core(nn.Module):
811
+ def __init__(self, opt):
812
+ super(Att2all2Core, self).__init__()
813
+ self.input_encoding_size = opt.input_encoding_size
814
+ #self.rnn_type = opt.rnn_type
815
+ self.rnn_size = opt.rnn_size
816
+ #self.num_layers = opt.num_layers
817
+ self.drop_prob_lm = opt.drop_prob_lm
818
+ self.fc_feat_size = opt.fc_feat_size
819
+ self.att_feat_size = opt.att_feat_size
820
+ self.att_hid_size = opt.att_hid_size
821
+
822
+ # Build a LSTM
823
+ self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
824
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
825
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
826
+ self.dropout = nn.Dropout(self.drop_prob_lm)
827
+
828
+ self.attention = Attention(opt)
829
+
830
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
831
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
832
+
833
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
834
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
835
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
836
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
837
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
838
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
839
+
840
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
841
+ in_transform = torch.max(\
842
+ in_transform.narrow(1, 0, self.rnn_size),
843
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
844
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
845
+ next_h = out_gate * torch.tanh(next_c)
846
+
847
+ output = self.dropout(next_h)
848
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
849
+ return output, state
850
+
851
+ class AdaAttModel(AttModel):
852
+ def __init__(self, opt):
853
+ super(AdaAttModel, self).__init__(opt)
854
+ self.core = AdaAttCore(opt)
855
+
856
+ # AdaAtt with maxout lstm
857
+ class AdaAttMOModel(AttModel):
858
+ def __init__(self, opt):
859
+ super(AdaAttMOModel, self).__init__(opt)
860
+ self.core = AdaAttCore(opt, True)
861
+
862
+ class Att2in2Model(AttModel):
863
+ def __init__(self, opt):
864
+ super(Att2in2Model, self).__init__(opt)
865
+ self.core = Att2in2Core(opt)
866
+ delattr(self, 'fc_embed')
867
+ self.fc_embed = lambda x : x
868
+
869
+ class Att2all2Model(AttModel):
870
+ def __init__(self, opt):
871
+ super(Att2all2Model, self).__init__(opt)
872
+ self.core = Att2all2Core(opt)
873
+ delattr(self, 'fc_embed')
874
+ self.fc_embed = lambda x : x
875
+
876
+ class UpDownModel(AttModel):
877
+ def __init__(self, opt):
878
+ super(UpDownModel, self).__init__(opt)
879
+ self.num_layers = 2
880
+ self.core = UpDownCore(opt)
881
+
882
+ class StackAttModel(AttModel):
883
+ def __init__(self, opt):
884
+ super(StackAttModel, self).__init__(opt)
885
+ self.num_layers = 3
886
+ self.core = StackAttCore(opt)
887
+
888
+ class DenseAttModel(AttModel):
889
+ def __init__(self, opt):
890
+ super(DenseAttModel, self).__init__(opt)
891
+ self.num_layers = 3
892
+ self.core = DenseAttCore(opt)
893
+
894
+ class Att2inModel(AttModel):
895
+ def __init__(self, opt):
896
+ super(Att2inModel, self).__init__(opt)
897
+ del self.embed, self.fc_embed, self.att_embed
898
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
899
+ self.fc_embed = self.att_embed = lambda x: x
900
+ del self.ctx2att
901
+ self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
902
+ self.core = Att2inCore(opt)
903
+ self.init_weights()
904
+
905
+ def init_weights(self):
906
+ initrange = 0.1
907
+ self.embed.weight.data.uniform_(-initrange, initrange)
908
+ self.logit.bias.data.fill_(0)
909
+ self.logit.weight.data.uniform_(-initrange, initrange)
910
+
911
+
912
+ class NewFCModel(AttModel):
913
+ def __init__(self, opt):
914
+ super(NewFCModel, self).__init__(opt)
915
+ self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
916
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
917
+ self._core = LSTMCore(opt)
918
+ delattr(self, 'att_embed')
919
+ self.att_embed = lambda x : x
920
+ delattr(self, 'ctx2att')
921
+ self.ctx2att = lambda x: x
922
+
923
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
924
+ # Step 0, feed the input image
925
+ # if (self.training and state[0].is_leaf) or \
926
+ # (not self.training and state[0].sum() == 0):
927
+ # _, state = self._core(fc_feats, state)
928
+ # three cases
929
+ # normal mle training
930
+ # Sample
931
+ # beam search (diverse beam search)
932
+ # fixed captioning module.
933
+ is_first_step = (state[0]==0).all(2).all(0) # size: B
934
+ if is_first_step.all():
935
+ _, state = self._core(fc_feats, state)
936
+ elif is_first_step.any():
937
+ # This is mostly for diverse beam search I think
938
+ new_state = [torch.zeros_like(_) for _ in state]
939
+ new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
940
+ new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
941
+ _, state = self._core(fc_feats, state)
942
+ new_state[0][:, is_first_step] = state[0][:, is_first_step]
943
+ new_state[1][:, is_first_step] = state[1][:, is_first_step]
944
+ state = new_state
945
+ # if (state[0]==0).all():
946
+ # # Let's forget about diverse beam search first
947
+ # _, state = self._core(fc_feats, state)
948
+ return self._core(xt, state)
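The core above only pushes fc_feats through the LSTM on the first decoding step, detecting "first step" per batch entry by an all-zero hidden state. A small sketch of that check under assumed shapes (not the committed code):

import torch

num_layers, batch, rnn_size = 1, 3, 4      # assumed toy shapes
h = torch.zeros(num_layers, batch, rnn_size)
h[:, 1] = 0.5                              # pretend example 1 has already taken a step

is_first_step = (h == 0).all(2).all(0)     # size: batch
print(is_first_step)                       # tensor([ True, False,  True])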
949
+
950
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
951
+ fc_feats = self.fc_embed(fc_feats)
952
+
953
+ return fc_feats, att_feats, att_feats, att_masks
954
+
955
+
956
+ class LMModel(AttModel):
957
+ def __init__(self, opt):
958
+ super(LMModel, self).__init__(opt)
959
+ delattr(self, 'fc_embed')
960
+ self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
961
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
962
+ self._core = LSTMCore(opt)
963
+ delattr(self, 'att_embed')
964
+ self.att_embed = lambda x : x
965
+ delattr(self, 'ctx2att')
966
+ self.ctx2att = lambda x: x
967
+
968
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
969
+ if (state[0]==0).all():
970
+ # Let's forget about diverse beam search first
971
+ _, state = self._core(fc_feats, state)
972
+ return self._core(xt, state)
973
+
974
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
975
+ fc_feats = self.fc_embed(fc_feats)
976
+
977
+ return fc_feats, None, None, None
captioning/models/BertCapModel.py ADDED
@@ -0,0 +1,103 @@
+ """
2
+ BertCapModel is using huggingface transformer bert model as seq2seq model.
3
+ The result is not as goog as original transformer.
4
+ """
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ import copy
15
+ import math
16
+ import numpy as np
17
+
18
+ from .CaptionModel import CaptionModel
19
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
20
+ try:
21
+ from transformers import BertModel, BertConfig
22
+ except:
23
+ print('Huggingface transformers not installed; please visit https://github.com/huggingface/transformers')
24
+ from .TransformerModel import subsequent_mask, TransformerModel, Generator
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.generator = generator
36
+
37
+ def forward(self, src, tgt, src_mask, tgt_mask):
38
+ "Take in and process masked src and target sequences."
39
+ return self.decode(self.encode(src, src_mask), src_mask,
40
+ tgt, tgt_mask)
41
+
42
+ def encode(self, src, src_mask):
43
+ return self.encoder(inputs_embeds=src,
44
+ attention_mask=src_mask)[0]
45
+
46
+ def decode(self, memory, src_mask, tgt, tgt_mask):
47
+ return self.decoder(input_ids=tgt,
48
+ attention_mask=tgt_mask,
49
+ encoder_hidden_states=memory,
50
+ encoder_attention_mask=src_mask)[0]
51
+
52
+
53
+ class BertCapModel(TransformerModel):
54
+
55
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
56
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
57
+ "Helper: Construct a model from hyperparameters."
58
+ enc_config = BertConfig(vocab_size=1,
59
+ hidden_size=d_model,
60
+ num_hidden_layers=N_enc,
61
+ num_attention_heads=h,
62
+ intermediate_size=d_ff,
63
+ hidden_dropout_prob=dropout,
64
+ attention_probs_dropout_prob=dropout,
65
+ max_position_embeddings=1,
66
+ type_vocab_size=1)
67
+ dec_config = BertConfig(vocab_size=tgt_vocab,
68
+ hidden_size=d_model,
69
+ num_hidden_layers=N_dec,
70
+ num_attention_heads=h,
71
+ intermediate_size=d_ff,
72
+ hidden_dropout_prob=dropout,
73
+ attention_probs_dropout_prob=dropout,
74
+ max_position_embeddings=17,
75
+ type_vocab_size=1,
76
+ is_decoder=True)
77
+ encoder = BertModel(enc_config)
78
+ def return_embeds(*args, **kwargs):
79
+ return kwargs['inputs_embeds']
80
+ del encoder.embeddings; encoder.embeddings = return_embeds
81
+ decoder = BertModel(dec_config)
82
+ model = EncoderDecoder(
83
+ encoder,
84
+ decoder,
85
+ Generator(d_model, tgt_vocab))
86
+ return model
87
+
88
+ def __init__(self, opt):
89
+ super(BertCapModel, self).__init__(opt)
90
+
91
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
92
+ """
93
+ state = [ys.unsqueeze(0)]
94
+ """
95
+ if len(state) == 0:
96
+ ys = it.unsqueeze(1)
97
+ else:
98
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
99
+ out = self.model.decode(memory, mask,
100
+ ys,
101
+ subsequent_mask(ys.size(1))
102
+ .to(memory.device))
103
+ return out[:, -1], [ys.unsqueeze(0)]
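core() rebuilds the growing prefix ys from state and decodes it with a causal mask so each position only attends to earlier tokens. A self-contained toy version of such a subsequent mask (the helper name and sizes here are assumptions, not the imported subsequent_mask):

import torch

def toy_subsequent_mask(size):
    # allow each position to attend to itself and to earlier positions only
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))

ys = torch.tensor([[0, 7, 12]])            # assumed running prefix of token ids
mask = toy_subsequent_mask(ys.size(1))
print(mask[0])
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])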
captioning/models/CaptionModel.py ADDED
@@ -0,0 +1,411 @@
+ # This file contains the base CaptionModel class, which implements the
+ # beam search, diverse beam search, and next-word sampling routines shared
+ # by the concrete captioning models in this package.
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.autograd import *
17
+ # from ..utils import misc as utils
18
+ from captioning.utils import misc as utils
19
+ from . import utils as model_utils
20
+
21
+ # torch.manual_seed(42)
22
+ # if torch.cuda.is_available():
23
+ # torch.cuda.manual_seed(42)
24
+
25
+ class CaptionModel(nn.Module):
26
+ def __init__(self):
27
+ super(CaptionModel, self).__init__()
28
+
29
+ # implements beam search
30
+ # calls beam_step and returns the final set of beams
31
+ # augments log-probabilities with diversity terms when number of groups > 1
32
+
33
+ def forward(self, *args, **kwargs):
34
+ mode = kwargs.get('mode', 'forward')
35
+ if 'mode' in kwargs:
36
+ del kwargs['mode']
37
+ return getattr(self, '_'+mode)(*args, **kwargs)
38
+
39
+ def beam_search(self, init_state, init_logprobs, *args, **kwargs):
40
+
41
+ # function computes the similarity score to be augmented
42
+ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
43
+ local_time = t - divm
44
+ unaug_logprobs = logprobs.clone()
45
+ batch_size = beam_seq_table[0].shape[0]
46
+
47
+ if divm > 0:
48
+ change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
49
+ for prev_choice in range(divm):
50
+ prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
51
+ for prev_labels in range(bdash):
52
+ change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
53
+
54
+ if local_time == 0:
55
+ logprobs = logprobs - change * diversity_lambda
56
+ else:
57
+ logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
58
+
59
+ return logprobs, unaug_logprobs
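add_diversity discourages the current group from reusing words that earlier groups already emitted at the same local time step, by subtracting diversity_lambda from those words' log-probabilities. A toy sketch of the penalty under assumed sizes (not the committed helper):

import torch

batch, bdash, vocab = 1, 2, 6                  # assumed toy sizes
diversity_lambda = 0.5
logprobs = torch.zeros(batch * bdash, vocab)   # scores for the current group
prev_decisions = torch.tensor([[3, 5]])        # words chosen by an earlier group (batch x bdash)

change = logprobs.new_zeros(batch, vocab)
change.scatter_add_(1, prev_decisions, torch.ones_like(prev_decisions, dtype=change.dtype))
logprobs = logprobs - change.repeat_interleave(bdash, dim=0) * diversity_lambda
print(logprobs)   # words 3 and 5 are now penalized for every beam in the group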
60
+
61
+
62
+ # does one step of classical beam search
63
+
64
+ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
65
+ #INPUTS:
+ #logprobs: probabilities augmented after diversity, N*bxV
+ #beam_size: number of beams kept per group
+ #t : time instant
+ #beam_seq : tensor containing the beams
+ #beam_seq_logprobs: tensor containing the beam logprobs
+ #beam_logprobs_sum: tensor containing joint logprobs
+ #OUTPUTS:
+ #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
+ #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
+ #beam_logprobs_sum : joint log-probability of each beam Nxb
76
+
77
+ batch_size = beam_logprobs_sum.shape[0]
78
+ vocab_size = logprobs.shape[-1]
79
+ logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
80
+ if t == 0:
81
+ assert logprobs.shape[1] == 1
82
+ beam_logprobs_sum = beam_logprobs_sum[:, :1]
83
+ candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
84
+ ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
85
+ ys, ix = ys[:,:beam_size], ix[:,:beam_size]
86
+ beam_ix = ix // vocab_size # Nxb which beam
87
+ selected_ix = ix % vocab_size # Nxb # which word
88
+ state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
89
+
90
+
91
+ if t > 0:
92
+ # gather according to beam_ix
93
+ assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
94
+ beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
95
+
96
+ beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
97
+
98
+ beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
99
+ beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
100
+ logprobs.reshape(batch_size, -1).gather(1, ix)
101
+ assert (beam_logprobs_sum == ys).all()
102
+ _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
103
+ beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
104
+ assert (_tmp_beam_logprobs == beam_logprobs).all()
105
+ beam_seq_logprobs = torch.cat([
106
+ beam_seq_logprobs,
107
+ beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
108
+
109
+ new_state = [None for _ in state]
110
+ for _ix in range(len(new_state)):
111
+ # copy over state in previous beam q to new beam at vix
112
+ new_state[_ix] = state[_ix][:, state_ix]
113
+ state = new_state
114
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
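beam_step flattens the (beam, vocab) candidate scores, keeps the top entries, and recovers the source beam and the chosen word with integer division and modulo. A standalone illustration of that indexing with assumed toy sizes (not the committed code):

import torch

batch, beams, vocab = 1, 2, 5                        # assumed toy sizes
candidate_logprobs = torch.randn(batch, beams, vocab)

ys, ix = torch.sort(candidate_logprobs.reshape(batch, -1), -1, True)
ys, ix = ys[:, :beams], ix[:, :beams]                # keep the top `beams` candidates
beam_ix = ix // vocab                                # which source beam each candidate extends
word_ix = ix % vocab                                 # which word extends it
print(beam_ix, word_ix)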
115
+
116
+ # Start diverse_beam_search
117
+ opt = kwargs['opt']
118
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
119
+ beam_size = opt.get('beam_size', 10)
120
+ group_size = opt.get('group_size', 1)
121
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
122
+ decoding_constraint = opt.get('decoding_constraint', 0)
123
+ remove_bad_endings = opt.get('remove_bad_endings', 1)
124
+ suppress_UNK = opt.get('suppress_UNK', 1)
125
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
126
+ bdash = beam_size // group_size # beam per group
127
+
128
+ batch_size = init_logprobs.shape[0]
129
+ device = init_logprobs.device
130
+ # INITIALIZATIONS
131
+ beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
132
+ beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
133
+ beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
134
+
135
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
136
+ done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
137
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
138
+ # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
139
+ state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
140
+ # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
141
+ logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
142
+ # END INIT
143
+
144
+ # Chunk elements in the args
145
+ args = list(args)
146
+ args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
147
+ if self.__class__.__name__ == 'AttEnsemble':
148
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
149
+ else:
150
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
151
+
152
+ for t in range(self.seq_length + group_size - 1):
153
+ for divm in range(group_size):
154
+ if t >= divm and t <= self.seq_length + divm - 1:
155
+ # add diversity
156
+ logprobs = logprobs_table[divm]
157
+ # suppress previous word
158
+ if decoding_constraint and t-divm > 0:
159
+ logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
160
+ if remove_bad_endings and t-divm > 0:
161
+ logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
162
+ # suppress UNK tokens in the decoding
163
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
164
+ logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
165
+ # diversity is added here
166
+ # the function directly modifies the logprobs values and hence, we need to return
167
+ # the unaugmented ones for sorting the candidates in the end. # for historical
168
+ # reasons :-)
169
+ logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
170
+
171
+ # infer new beams
172
+ beam_seq_table[divm],\
173
+ beam_seq_logprobs_table[divm],\
174
+ beam_logprobs_sum_table[divm],\
175
+ state_table[divm] = beam_step(logprobs,
176
+ unaug_logprobs,
177
+ bdash,
178
+ t-divm,
179
+ beam_seq_table[divm],
180
+ beam_seq_logprobs_table[divm],
181
+ beam_logprobs_sum_table[divm],
182
+ state_table[divm])
183
+
184
+ # if time's up... or if end token is reached then copy beams
185
+ for b in range(batch_size):
186
+ is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
187
+ assert beam_seq_table[divm].shape[-1] == t-divm+1
188
+ if t == self.seq_length + divm - 1:
189
+ is_end.fill_(1)
190
+ for vix in range(bdash):
191
+ if is_end[vix]:
192
+ final_beam = {
193
+ 'seq': beam_seq_table[divm][b, vix].clone(),
194
+ 'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
195
+ 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
196
+ 'p': beam_logprobs_sum_table[divm][b, vix].item()
197
+ }
198
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
199
+ done_beams_table[b][divm].append(final_beam)
200
+ beam_logprobs_sum_table[divm][b, is_end] -= 1000
201
+
202
+ # move the current group one step forward in time
203
+
204
+ it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
205
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
206
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
207
+
208
+ # all beams are sorted by their log-probabilities
209
+ done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
210
+ done_beams = [sum(_, []) for _ in done_beams_table]
211
+ return done_beams
212
+
213
+ def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
214
+
215
+ # function computes the similarity score to be augmented
216
+ def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
217
+ local_time = t - divm
218
+ unaug_logprobsf = logprobsf.clone()
219
+ for prev_choice in range(divm):
220
+ prev_decisions = beam_seq_table[prev_choice][local_time]
221
+ for sub_beam in range(bdash):
222
+ for prev_labels in range(bdash):
223
+ logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
224
+ return unaug_logprobsf
225
+
226
+ # does one step of classical beam search
227
+
228
+ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
229
+ #INPUTS:
+ #logprobsf: probabilities augmented after diversity
+ #beam_size: number of beams kept per group
+ #t : time instant
+ #beam_seq : tensor containing the beams
+ #beam_seq_logprobs: tensor containing the beam logprobs
+ #beam_logprobs_sum: tensor containing joint logprobs
+ #OUTPUTS:
+ #beam_seq : tensor containing the word indices of the decoded captions
+ #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
+ #beam_logprobs_sum : joint log-probability of each beam
240
+
241
+ ys,ix = torch.sort(logprobsf,1,True)
242
+ candidates = []
243
+ cols = min(beam_size, ys.size(1))
244
+ rows = beam_size
245
+ if t == 0:
246
+ rows = 1
247
+ for c in range(cols): # for each column (word, essentially)
248
+ for q in range(rows): # for each beam expansion
249
+ #compute logprob of expanding beam q with word in (sorted) position c
250
+ local_logprob = ys[q,c].item()
251
+ candidate_logprob = beam_logprobs_sum[q] + local_logprob
252
+ # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
253
+ candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
254
+ candidates = sorted(candidates, key=lambda x: -x['p'])
255
+
256
+ new_state = [_.clone() for _ in state]
257
+ #beam_seq_prev, beam_seq_logprobs_prev
258
+ if t >= 1:
259
+ #we'll need these as reference when we fork beams around
260
+ beam_seq_prev = beam_seq[:t].clone()
261
+ beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
262
+ for vix in range(beam_size):
263
+ v = candidates[vix]
264
+ #fork beam index q into index vix
265
+ if t >= 1:
266
+ beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
267
+ beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
268
+ #rearrange recurrent states
269
+ for state_ix in range(len(new_state)):
270
+ # copy over state in previous beam q to new beam at vix
271
+ new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
272
+ #append new end terminal at the end of this beam
273
+ beam_seq[t, vix] = v['c'] # c'th word is the continuation
274
+ beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
275
+ beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
276
+ state = new_state
277
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
278
+
279
+ # Start diverse_beam_search
280
+ opt = kwargs['opt']
281
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
282
+ beam_size = opt.get('beam_size', 10)
283
+ group_size = opt.get('group_size', 1)
284
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
285
+ decoding_constraint = opt.get('decoding_constraint', 0)
286
+ remove_bad_endings = opt.get('remove_bad_endings', 1)
287
+ suppress_UNK = opt.get('suppress_UNK', 1)
288
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
289
+ bdash = beam_size // group_size # beam per group
290
+
291
+ # INITIALIZATIONS
292
+ beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
293
+ beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
294
+ beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
295
+
296
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
297
+ done_beams_table = [[] for _ in range(group_size)]
298
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
299
+ state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
300
+ logprobs_table = list(init_logprobs.chunk(group_size, 0))
301
+ # END INIT
302
+
303
+ # Chunk elements in the args
304
+ args = list(args)
305
+ if self.__class__.__name__ == 'AttEnsemble':
306
+ args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
307
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
308
+ else:
309
+ args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
310
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
311
+
312
+ for t in range(self.seq_length + group_size - 1):
313
+ for divm in range(group_size):
314
+ if t >= divm and t <= self.seq_length + divm - 1:
315
+ # add diversity
316
+ logprobsf = logprobs_table[divm]
317
+ # suppress previous word
318
+ if decoding_constraint and t-divm > 0:
319
+ logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
320
+ if remove_bad_endings and t-divm > 0:
321
+ logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
322
+ # suppress UNK tokens in the decoding
323
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
324
+ logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
325
+ # diversity is added here
326
+ # the function directly modifies the logprobsf values and hence, we need to return
327
+ # the unaugmented ones for sorting the candidates in the end. # for historical
328
+ # reasons :-)
329
+ unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
330
+
331
+ # infer new beams
332
+ beam_seq_table[divm],\
333
+ beam_seq_logprobs_table[divm],\
334
+ beam_logprobs_sum_table[divm],\
335
+ state_table[divm],\
336
+ candidates_divm = beam_step(logprobsf,
337
+ unaug_logprobsf,
338
+ bdash,
339
+ t-divm,
340
+ beam_seq_table[divm],
341
+ beam_seq_logprobs_table[divm],
342
+ beam_logprobs_sum_table[divm],
343
+ state_table[divm])
344
+
345
+ # if time's up... or if end token is reached then copy beams
346
+ for vix in range(bdash):
347
+ if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
348
+ final_beam = {
349
+ 'seq': beam_seq_table[divm][:, vix].clone(),
350
+ 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
351
+ 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
352
+ 'p': beam_logprobs_sum_table[divm][vix].item()
353
+ }
354
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
355
+ done_beams_table[divm].append(final_beam)
356
+ # don't continue beams from finished sequences
357
+ beam_logprobs_sum_table[divm][vix] = -1000
358
+
359
+ # move the current group one step forward in time
360
+
361
+ it = beam_seq_table[divm][t-divm].to(logprobsf.device)
362
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
363
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
364
+
365
+ # all beams are sorted by their log-probabilities
366
+ done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
367
+ done_beams = sum(done_beams_table, [])
368
+ return done_beams
369
+
370
+ def sample_next_word(self, logprobs, sample_method, temperature):
371
+ if sample_method == 'greedy':
372
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
373
+ it = it.view(-1).long()
374
+ elif sample_method == 'gumbel': # gumbel softmax
375
+ # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
376
+ def sample_gumbel(shape, eps=1e-20):
377
+ U = torch.rand(shape).to(logprobs.device)
378
+ return -torch.log(-torch.log(U + eps) + eps)
379
+ def gumbel_softmax_sample(logits, temperature):
380
+ y = logits + sample_gumbel(logits.size())
381
+ return F.log_softmax(y / temperature, dim=-1)
382
+ _logprobs = gumbel_softmax_sample(logprobs, temperature)
383
+ _, it = torch.max(_logprobs.data, 1)
384
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
385
+ else:
386
+ logprobs = logprobs / temperature
387
+ if sample_method.startswith('top'): # topk sampling
388
+ top_num = float(sample_method[3:])
389
+ if 0 < top_num < 1:
390
+ # nucleus sampling from # The Curious Case of Neural Text Degeneration
391
+ probs = F.softmax(logprobs, dim=1)
392
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
393
+ _cumsum = sorted_probs.cumsum(1)
394
+ mask = _cumsum < top_num
395
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
396
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
397
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
398
+ logprobs.scatter_(1, sorted_indices, sorted_probs.log())
399
+ else:
400
+ the_k = int(top_num)
401
+ tmp = torch.empty_like(logprobs).fill_(float('-inf'))
402
+ topk, indices = torch.topk(logprobs, the_k, dim=1)
403
+ tmp = tmp.scatter(1, indices, topk)
404
+ logprobs = tmp
405
+ it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
406
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
407
+ return it, sampleLogprobs
408
+
409
+
410
+ def decode_sequence(self, seq):
411
+ return utils.decode_sequence(self.vocab, seq)
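sample_next_word supports greedy, Gumbel, top-k, and nucleus (top-p) sampling. A standalone sketch of the nucleus-sampling mask on an assumed five-word toy distribution (not the committed code):

import torch
import torch.nn.functional as F

logprobs = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.07, 0.03]]))  # assumed toy distribution
top_p = 0.8

probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
cumsum = sorted_probs.cumsum(1)
mask = cumsum < top_p
mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)  # always keep the top token
sorted_probs = sorted_probs * mask
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
it = sorted_indices.gather(1, torch.multinomial(sorted_probs, 1))  # sampled token id
print(it)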
captioning/models/FCModel.py ADDED
@@ -0,0 +1,204 @@
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.autograd import *
9
+ from . import utils
10
+
11
+ from .CaptionModel import CaptionModel
12
+
13
+ class LSTMCore(nn.Module):
14
+ def __init__(self, opt):
15
+ super(LSTMCore, self).__init__()
16
+ self.input_encoding_size = opt.input_encoding_size
17
+ self.rnn_size = opt.rnn_size
18
+ self.drop_prob_lm = opt.drop_prob_lm
19
+
20
+ # Build a LSTM
21
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
22
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
23
+ self.dropout = nn.Dropout(self.drop_prob_lm)
24
+
25
+ def forward(self, xt, state):
26
+
27
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
28
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
29
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
30
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
31
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
32
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
33
+
34
+ in_transform = torch.max(\
35
+ all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
36
+ all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
37
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
38
+ next_h = out_gate * torch.tanh(next_c)
39
+
40
+ output = self.dropout(next_h)
41
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
42
+ return output, state
43
+
44
+ class FCModel(CaptionModel):
45
+ def __init__(self, opt):
46
+ super(FCModel, self).__init__()
47
+ self.vocab_size = opt.vocab_size
48
+ self.input_encoding_size = opt.input_encoding_size
49
+ self.rnn_type = opt.rnn_type
50
+ self.rnn_size = opt.rnn_size
51
+ self.num_layers = opt.num_layers
52
+ self.drop_prob_lm = opt.drop_prob_lm
53
+ self.seq_length = opt.seq_length
54
+ self.fc_feat_size = opt.fc_feat_size
55
+
56
+ self.ss_prob = 0.0 # Schedule sampling probability
57
+
58
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
59
+ self.core = LSTMCore(opt)
60
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
61
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
62
+
63
+ self.init_weights()
64
+
65
+ def init_weights(self):
66
+ initrange = 0.1
67
+ self.embed.weight.data.uniform_(-initrange, initrange)
68
+ self.logit.bias.data.fill_(0)
69
+ self.logit.weight.data.uniform_(-initrange, initrange)
70
+
71
+ def init_hidden(self, bsz):
72
+ weight = self.logit.weight
73
+ if self.rnn_type == 'lstm':
74
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
75
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
76
+ else:
77
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
78
+
79
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
80
+ batch_size = fc_feats.size(0)
81
+ seq_per_img = seq.shape[0] // batch_size
82
+ state = self.init_hidden(batch_size*seq_per_img)
83
+ outputs = []
84
+
85
+ if seq_per_img > 1:
86
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
87
+
88
+ for i in range(seq.size(1) + 1):
89
+ if i == 0:
90
+ xt = self.img_embed(fc_feats)
91
+ else:
92
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
93
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
94
+ sample_mask = sample_prob < self.ss_prob
95
+ if sample_mask.sum() == 0:
96
+ it = seq[:, i-1].clone()
97
+ else:
98
+ sample_ind = sample_mask.nonzero().view(-1)
99
+ it = seq[:, i-1].data.clone()
100
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
101
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
102
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
103
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
104
+ else:
105
+ it = seq[:, i-1].clone()
106
+ # break if all the sequences end
107
+ if i >= 2 and seq[:, i-1].sum() == 0:
108
+ break
109
+ xt = self.embed(it)
110
+
111
+ output, state = self.core(xt, state)
112
+ output = F.log_softmax(self.logit(output), dim=1)
113
+ outputs.append(output)
114
+
115
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
116
+
117
+ def get_logprobs_state(self, it, state):
118
+ # 'it' is contains a word index
119
+ xt = self.embed(it)
120
+
121
+ output, state = self.core(xt, state)
122
+ logprobs = F.log_softmax(self.logit(output), dim=1)
123
+
124
+ return logprobs, state
125
+
126
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
127
+ beam_size = opt.get('beam_size', 10)
128
+ batch_size = fc_feats.size(0)
129
+
130
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
131
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
132
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
133
+ # lets process every image independently for now, for simplicity
134
+
135
+ self.done_beams = [[] for _ in range(batch_size)]
136
+ for k in range(batch_size):
137
+ state = self.init_hidden(beam_size)
138
+ for t in range(2):
139
+ if t == 0:
140
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
141
+ elif t == 1: # input <bos>
142
+ it = fc_feats.data.new(beam_size).long().zero_()
143
+ xt = self.embed(it)
144
+
145
+ output, state = self.core(xt, state)
146
+ logprobs = F.log_softmax(self.logit(output), dim=1)
147
+
148
+ self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
149
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
150
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
151
+ # return the samples and their log likelihoods
152
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
153
+
154
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
155
+ sample_method = opt.get('sample_method', 'greedy')
156
+ beam_size = opt.get('beam_size', 1)
157
+ temperature = opt.get('temperature', 1.0)
158
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
159
+ return self._sample_beam(fc_feats, att_feats, opt)
160
+
161
+ batch_size = fc_feats.size(0)
162
+ state = self.init_hidden(batch_size)
163
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
164
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
165
+ for t in range(self.seq_length + 2):
166
+ if t == 0:
167
+ xt = self.img_embed(fc_feats)
168
+ else:
169
+ if t == 1: # input <bos>
170
+ it = fc_feats.data.new(batch_size).long().zero_()
171
+ xt = self.embed(it)
172
+
173
+ output, state = self.core(xt, state)
174
+ logprobs = F.log_softmax(self.logit(output), dim=1)
175
+
176
+ # sample the next_word
177
+ if t == self.seq_length + 1: # skip if we achieve maximum length
178
+ break
179
+ if sample_method == 'greedy':
180
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
181
+ it = it.view(-1).long()
182
+ else:
183
+ if temperature == 1.0:
184
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
185
+ else:
186
+ # scale logprobs by temperature
187
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
188
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
189
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
190
+ it = it.view(-1).long() # and flatten indices for downstream processing
191
+
192
+ if t >= 1:
193
+ # stop when all finished
194
+ if t == 1:
195
+ unfinished = it > 0
196
+ else:
197
+ unfinished = unfinished & (it > 0)
198
+ it = it * unfinished.type_as(it)
199
+ seq[:,t-1] = it #seq[t] the input of t+2 time step
200
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
201
+ if unfinished.sum() == 0:
202
+ break
203
+
204
+ return seq, seqLogprobs
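The ss_prob branch in _forward implements scheduled sampling: with probability ss_prob, a training step feeds back the model's own sampled word instead of the ground-truth word. A per-example toy sketch of that coin flip (assumed sizes, not the committed code):

import torch

batch, vocab = 4, 10                               # assumed toy sizes
ss_prob = 0.25
gt_tokens = torch.randint(1, vocab, (batch,))      # ground-truth words for this step
prev_logprobs = torch.log_softmax(torch.randn(batch, vocab), dim=1)

sample_mask = torch.rand(batch) < ss_prob          # which examples use the model's sample
it = gt_tokens.clone()
sampled = torch.multinomial(prev_logprobs.exp(), 1).view(-1)
it[sample_mask] = sampled[sample_mask]             # replace only the masked examples
print(it)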
captioning/models/M2Transformer.py ADDED
@@ -0,0 +1,102 @@
+ """
2
+ Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
3
+
4
+ pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
5
+
6
+ Note:
7
+ Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating.
8
+ """
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ import copy
19
+ import math
20
+ import numpy as np
21
+
22
+ from .CaptionModel import CaptionModel
23
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
24
+
25
+ try:
26
+ from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
27
+ except:
28
+ print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
29
+ from .TransformerModel import subsequent_mask, TransformerModel
30
+
31
+
32
+ class M2TransformerModel(TransformerModel):
33
+
34
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
35
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
36
+ "Helper: Construct a model from hyperparameters."
37
+ encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
38
+ attention_module_kwargs={'m': 40})
39
+ # Another implementation is to use MultiLevelEncoder + att_embed
40
+ decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
41
+ model = Transformer(0, encoder, decoder) # 0 is bos
42
+ return model
43
+
44
+ def __init__(self, opt):
45
+ super(M2TransformerModel, self).__init__(opt)
46
+ delattr(self, 'att_embed')
47
+ self.att_embed = lambda x: x # The visual embed is in the MAEncoder
48
+ # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
49
+ # Also the attention mask seems wrong in MAEncoder too... interesting
50
+
51
+ def logit(self, x): # unsafe way
52
+ return x # M2transformer always output logsoftmax
53
+
54
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
55
+
56
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
57
+ memory, att_masks = self.model.encoder(att_feats)
58
+
59
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
60
+
61
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
62
+ if seq.ndim == 3: # B * seq_per_img * seq_len
63
+ seq = seq.reshape(-1, seq.shape[2])
64
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
65
+
66
+ seq = seq.clone()
67
+ seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
68
+ outputs = self.model(att_feats, seq)
69
+
70
+ # out = self.model(att_feats, seq, att_masks, seq_mask)
71
+
72
+ # outputs = self.model.generator(out)
73
+
74
+ return outputs
75
+
76
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
77
+ """
78
+ state = [ys.unsqueeze(0)]
79
+ """
80
+ if len(state) == 0:
81
+ ys = it.unsqueeze(1)
82
+ else:
83
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
84
+ out = self.model.decoder(ys, memory, mask)
85
+ return out[:, -1], [ys.unsqueeze(0)]
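As in BertCapModel, core() keeps the whole decoded prefix inside state, re-decodes it every step, and returns only the last position. A toy sketch of that prefix bookkeeping with assumed token ids (not the committed code):

import torch

state = []                                  # empty at the first step
for it in (torch.tensor([2]), torch.tensor([7]), torch.tensor([4])):
    if len(state) == 0:
        ys = it.unsqueeze(1)                # (batch, 1)
    else:
        ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
    state = [ys.unsqueeze(0)]               # stash the growing prefix
print(state[0][0])                          # tensor([[2, 7, 4]])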
86
+
87
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
88
+ beam_size = opt.get('beam_size', 10)
89
+ group_size = opt.get('group_size', 1)
90
+ sample_n = opt.get('sample_n', 10)
91
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
92
+
93
+ att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
94
+ seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
95
+ beam_size, return_probs=True, out_size=beam_size)
96
+ seq = seq.reshape(-1, *seq.shape[2:])
97
+ seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
98
+
99
+ # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
100
+ # import pudb;pu.db
101
+ # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
102
+ return seq, seqLogprobs
captioning/models/OldModel.py ADDED
@@ -0,0 +1,265 @@
+ # This file contains ShowAttendTell and AllImg model
2
+
3
+ # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
4
+ # https://arxiv.org/abs/1502.03044
5
+
6
+ # AllImg is a model where
7
+ # img feature is concatenated with word embedding at every time step as the input of lstm
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.autograd import *
16
+ # import misc.utils as utils
17
+ # import utils as utils
18
+ from . import utils
19
+
20
+ from .CaptionModel import CaptionModel
21
+
22
+
23
+ class OldModel(CaptionModel):
24
+ def __init__(self, opt):
25
+ super(OldModel, self).__init__()
26
+ self.vocab_size = opt.vocab_size
27
+ self.input_encoding_size = opt.input_encoding_size
28
+ self.rnn_type = opt.rnn_type
29
+ self.rnn_size = opt.rnn_size
30
+ self.num_layers = opt.num_layers
31
+ self.drop_prob_lm = opt.drop_prob_lm
32
+ self.seq_length = opt.seq_length
33
+ self.fc_feat_size = opt.fc_feat_size
34
+ self.att_feat_size = opt.att_feat_size
35
+
36
+ self.ss_prob = 0.0 # Schedule sampling probability
37
+
38
+ self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size
39
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
40
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
41
+ self.dropout = nn.Dropout(self.drop_prob_lm)
42
+
43
+ self.init_weights()
44
+
45
+ def init_weights(self):
46
+ initrange = 0.1
47
+ self.embed.weight.data.uniform_(-initrange, initrange)
48
+ self.logit.bias.data.fill_(0)
49
+ self.logit.weight.data.uniform_(-initrange, initrange)
50
+
51
+ def init_hidden(self, fc_feats):
52
+ image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1)
53
+ if self.rnn_type == 'lstm':
54
+ return (image_map, image_map)
55
+ else:
56
+ return image_map
57
+
58
+ def forward(self, fc_feats, att_feats, seq):
59
+ batch_size = fc_feats.size(0)
60
+ state = self.init_hidden(fc_feats)
61
+
62
+ outputs = []
63
+
64
+ for i in range(seq.size(1) - 1):
65
+ if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
66
+ sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
67
+ sample_mask = sample_prob < self.ss_prob
68
+ if sample_mask.sum() == 0:
69
+ it = seq[:, i].clone()
70
+ else:
71
+ sample_ind = sample_mask.nonzero().view(-1)
72
+ it = seq[:, i].data.clone()
73
+ # prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
74
+ # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
75
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
76
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
77
+ it = Variable(it, requires_grad=False)
78
+ else:
79
+ it = seq[:, i].clone()
80
+ # break if all the sequences end
81
+ if i >= 1 and seq[:, i].data.sum() == 0:
82
+ break
83
+
84
+ xt = self.embed(it)
85
+
86
+ output, state = self.core(xt, fc_feats, att_feats, state)
87
+ output = F.log_softmax(self.logit(self.dropout(output)))
88
+ outputs.append(output)
89
+
90
+ return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
91
+
92
+ def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state):
93
+ # 'it' is a Variable containing a word index
94
+ xt = self.embed(it)
95
+
96
+ output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
97
+ logprobs = F.log_softmax(self.logit(self.dropout(output)))
98
+
99
+ return logprobs, state
100
+
101
+ def sample_beam(self, fc_feats, att_feats, opt={}):
102
+ beam_size = opt.get('beam_size', 10)
103
+ batch_size = fc_feats.size(0)
104
+
105
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
106
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
107
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
108
+ # lets process every image independently for now, for simplicity
109
+
110
+ self.done_beams = [[] for _ in range(batch_size)]
111
+ for k in range(batch_size):
112
+ tmp_fc_feats = fc_feats[k:k + 1].expand(beam_size, self.fc_feat_size)
113
+ tmp_att_feats = att_feats[k:k + 1].expand(*((beam_size,) + att_feats.size()[1:])).contiguous()
114
+
115
+ state = self.init_hidden(tmp_fc_feats)
116
+
117
+ beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
118
+ beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
119
+ beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
120
+ done_beams = []
121
+ for t in range(1):
122
+ if t == 0: # input <bos>
123
+ it = fc_feats.data.new(beam_size).long().zero_()
124
+ xt = self.embed(Variable(it, requires_grad=False))
125
+
126
+ output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
127
+ logprobs = F.log_softmax(self.logit(self.dropout(output)))
128
+
129
+ self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt)
130
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
131
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
132
+ # return the samples and their log likelihoods
133
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
134
+
135
+ def sample(self, fc_feats, att_feats, opt={}):
136
+ sample_max = opt.get('sample_max', 1)
137
+ beam_size = opt.get('beam_size', 1)
138
+ temperature = opt.get('temperature', 1.0)
139
+ if beam_size > 1:
140
+ return self.sample_beam(fc_feats, att_feats, opt)
141
+
142
+ batch_size = fc_feats.size(0)
143
+ state = self.init_hidden(fc_feats)
144
+
145
+ seq = []
146
+ seqLogprobs = []
147
+ for t in range(self.seq_length + 1):
148
+ if t == 0: # input <bos>
149
+ it = fc_feats.data.new(batch_size).long().zero_()
150
+ elif sample_max:
151
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
152
+ it = it.view(-1).long()
153
+ else:
154
+ if temperature == 1.0:
155
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
156
+ else:
157
+ # scale logprobs by temperature
158
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
159
+ it = torch.multinomial(prob_prev, 1).cuda()
160
+ sampleLogprobs = logprobs.gather(1, Variable(it,
161
+ requires_grad=False)) # gather the logprobs at sampled positions
162
+ it = it.view(-1).long() # and flatten indices for downstream processing
163
+
164
+ xt = self.embed(Variable(it, requires_grad=False))
165
+
166
+ if t >= 1:
167
+ # stop when all finished
168
+ if t == 1:
169
+ unfinished = it > 0
170
+ else:
171
+ unfinished = unfinished * (it > 0)
172
+ if unfinished.sum() == 0:
173
+ break
174
+ it = it * unfinished.type_as(it)
175
+ seq.append(it) # seq[t] the input of t+2 time step
176
+ seqLogprobs.append(sampleLogprobs.view(-1))
177
+
178
+ output, state = self.core(xt, fc_feats, att_feats, state)
179
+ logprobs = F.log_softmax(self.logit(self.dropout(output)), -1)
180
+
181
+ return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
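When sample_max is off, the sampling branch above divides the log-probabilities by a temperature before exponentiating, which sharpens or flattens the distribution. A toy sketch of the effect under an assumed three-word distribution (not the committed code):

import torch

logprobs = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5]]), dim=1)  # assumed toy scores
for temperature in (0.5, 1.0, 2.0):
    prob_prev = torch.exp(logprobs / temperature)
    prob_prev = prob_prev / prob_prev.sum(1, keepdim=True)   # renormalize after scaling
    print(temperature, prob_prev)
# lower temperature -> sharper distribution, higher -> closer to uniform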
182
+
183
+
184
+ class ShowAttendTellCore(nn.Module):
185
+ def __init__(self, opt):
186
+ super(ShowAttendTellCore, self).__init__()
187
+ self.input_encoding_size = opt.input_encoding_size
188
+ self.rnn_type = opt.rnn_type
189
+ self.rnn_size = opt.rnn_size
190
+ self.num_layers = opt.num_layers
191
+ self.drop_prob_lm = opt.drop_prob_lm
192
+ self.fc_feat_size = opt.fc_feat_size
193
+ self.att_feat_size = opt.att_feat_size
194
+ self.att_hid_size = opt.att_hid_size
195
+
196
+ self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size,
197
+ self.rnn_size, self.num_layers, bias=False,
198
+ dropout=self.drop_prob_lm)
199
+
200
+ if self.att_hid_size > 0:
201
+ self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
202
+ self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
203
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
204
+ else:
205
+ self.ctx2att = nn.Linear(self.att_feat_size, 1)
206
+ self.h2att = nn.Linear(self.rnn_size, 1)
207
+
208
+ def forward(self, xt, fc_feats, att_feats, state):
209
+ att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
210
+ att = att_feats.view(-1, self.att_feat_size)
211
+ if self.att_hid_size > 0:
212
+ att = self.ctx2att(att) # (batch * att_size) * att_hid_size
213
+ att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size
214
+ att_h = self.h2att(state[0][-1]) # batch * att_hid_size
215
+ att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
216
+ dot = att + att_h # batch * att_size * att_hid_size
217
+ dot = torch.tanh(dot) # batch * att_size * att_hid_size
218
+ dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
219
+ dot = self.alpha_net(dot) # (batch * att_size) * 1
220
+ dot = dot.view(-1, att_size) # batch * att_size
221
+ else:
222
+ att = self.ctx2att(att) # (batch * att_size) * 1
223
+ att = att.view(-1, att_size) # batch * att_size
224
+ att_h = self.h2att(state[0][-1]) # batch * 1
225
+ att_h = att_h.expand_as(att) # batch * att_size
226
+ dot = att_h + att # batch * att_size
227
+
228
+ weight = F.softmax(dot, -1)
229
+ att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size
230
+ att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
231
+
232
+ output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state)
233
+ return output.squeeze(0), state
234
+
235
+
236
+ class AllImgCore(nn.Module):
237
+ def __init__(self, opt):
238
+ super(AllImgCore, self).__init__()
239
+ self.input_encoding_size = opt.input_encoding_size
240
+ self.rnn_type = opt.rnn_type
241
+ self.rnn_size = opt.rnn_size
242
+ self.num_layers = opt.num_layers
243
+ self.drop_prob_lm = opt.drop_prob_lm
244
+ self.fc_feat_size = opt.fc_feat_size
245
+
246
+ self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size,
247
+ self.rnn_size, self.num_layers, bias=False,
248
+ dropout=self.drop_prob_lm)
249
+
250
+ def forward(self, xt, fc_feats, att_feats, state):
251
+ output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state)
252
+ return output.squeeze(0), state
253
+
254
+
255
+ class ShowAttendTellModel(OldModel):
256
+ def __init__(self, opt):
257
+ super(ShowAttendTellModel, self).__init__(opt)
258
+ self.core = ShowAttendTellCore(opt)
259
+
260
+
261
+ class AllImgModel(OldModel):
262
+ def __init__(self, opt):
263
+ super(AllImgModel, self).__init__(opt)
264
+ self.core = AllImgCore(opt)
265
+
captioning/models/ShowTellModel.py ADDED
@@ -0,0 +1,368 @@
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.autograd import *
10
+ from . import utils
11
+
12
+ from .CaptionModel import CaptionModel
13
+
14
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
15
+ bad_endings += ['UNK', 'has', 'and', 'more']
16
+
17
+ # torch.manual_seed(42)
18
+ # if torch.cuda.is_available():
19
+ # torch.cuda.manual_seed(42)
20
+
21
+ class ShowTellModel(CaptionModel):
22
+ def __init__(self, opt):
23
+ super(ShowTellModel, self).__init__()
24
+ self.vocab_size = opt.vocab_size
25
+ self.input_encoding_size = opt.input_encoding_size
26
+ self.rnn_type = opt.rnn_type
27
+ self.rnn_size = opt.rnn_size
28
+ self.num_layers = opt.num_layers
29
+ self.drop_prob_lm = opt.drop_prob_lm
30
+ self.seq_length = opt.seq_length
31
+ self.fc_feat_size = opt.fc_feat_size
32
+
33
+ self.eos_idx = getattr(opt, 'eos_idx', 0)
34
+ self.pad_idx = getattr(opt, 'pad_idx', 0)
35
+
36
+ self.ss_prob = 0.0 # Scheduled sampling probability
37
+
38
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
39
+ self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
40
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
41
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
42
+ self.dropout = nn.Dropout(self.drop_prob_lm)
43
+
44
+ # For removing bad endings
45
+ self.vocab = opt.vocab
46
+ self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
47
+
48
+ self.init_weights()
49
+
50
+ def init_weights(self):
51
+ initrange = 0.1
52
+ self.embed.weight.data.uniform_(-initrange, initrange)
53
+ self.logit.bias.data.fill_(0)
54
+ self.logit.weight.data.uniform_(-initrange, initrange)
55
+
56
+ def init_hidden(self, bsz):
57
+ weight = self.logit.weight
58
+ if self.rnn_type == 'lstm':
59
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
60
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
61
+ else:
62
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
63
+
64
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
65
+
66
+ batch_size = fc_feats.size(0)
67
+ if seq.ndim == 3: # B * seq_per_img * seq_len
68
+ seq = seq.reshape(-1, seq.shape[2])
69
+ seq_per_img = seq.shape[0] // batch_size
70
+ state = self.init_hidden(batch_size*seq_per_img)
71
+ outputs = []
72
+
73
+ if seq_per_img > 1:
74
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
75
+
76
+ for i in range(seq.size(1)+1):
77
+ if i == 0:
78
+ xt = self.img_embed(fc_feats)
79
+ else:
80
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
81
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
82
+ sample_mask = sample_prob < self.ss_prob
83
+ if sample_mask.sum() == 0:
84
+ it = seq[:, i-1].clone()
85
+ else:
86
+ sample_ind = sample_mask.nonzero().view(-1)
87
+ it = seq[:, i-1].data.clone()
88
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
89
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
90
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
91
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
92
+ else:
93
+ it = seq[:, i-1].clone()
94
+ # break if all the sequences end
95
+ if i >= 2 and seq[:, i-1].data.sum() == 0:
96
+ break
97
+ xt = self.embed(it)
98
+
99
+ output, state = self.core(xt.unsqueeze(0), state)
100
+
101
+ output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
102
+ outputs.append(output)
103
+
104
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
105
+
106
+ def get_logprobs_state(self, it, state):
107
+ # 'it' contains a word index
108
+ xt = self.embed(it)
109
+
110
+ output, state = self.core(xt.unsqueeze(0), state)
111
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
112
+
113
+ return logprobs, state
114
+
115
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
116
+ # beam_size = opt.get('beam_size', 10)
117
+ # batch_size = fc_feats.size(0)
118
+
119
+ # assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
120
+ # seq = torch.LongTensor(self.seq_length, batch_size).zero_()
121
+ # seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
122
+ # # lets process every image independently for now, for simplicity
123
+
124
+
125
+ beam_size = opt.get('beam_size', 10)
126
+ group_size = opt.get('group_size', 1)
127
+ sample_n = opt.get('sample_n', 10)
128
+ # when sample_n == beam_size then each beam is a sample.
129
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
130
+ batch_size = fc_feats.size(0)
131
+
132
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
133
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
134
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
135
+
136
+ self.done_beams = [[] for _ in range(batch_size)]
137
+ for k in range(batch_size):
138
+ state = self.init_hidden(beam_size)
139
+ for t in range(2):
140
+ if t == 0:
141
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
142
+ elif t == 1: # input <bos>
143
+ it = fc_feats.data.new(beam_size).long().zero_()
144
+ xt = self.embed(it)
145
+
146
+ output, state = self.core(xt.unsqueeze(0), state)
147
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
148
+
149
+ self.done_beams[k] = self.old_beam_search(state, logprobs, opt=opt)
150
+ if sample_n == beam_size:
151
+ for _n in range(sample_n):
152
+ seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
153
+ seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
154
+ else:
155
+ seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
156
+ seqLogprobs[k, :] = self.done_beams[k][0]['logps']
157
+ # return the samples and their log likelihoods
158
+ return seq, seqLogprobs
159
+
160
+ # seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
161
+ # seqLogprobs[:, k] = self.done_beams[k][0]['logps']
162
+ # # return the samples and their log likelihoods
163
+ # return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
164
+
165
+ def _new_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
166
+
167
+ beam_size = opt.get('beam_size', 10)
168
+ group_size = opt.get('group_size', 1)
169
+ sample_n = opt.get('sample_n', 10)
170
+ # when sample_n == beam_size then each beam is a sample.
171
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
172
+ batch_size = fc_feats.size(0)
173
+
174
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
175
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
176
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
177
+
178
+ self.done_beams = [[] for _ in range(batch_size)]
179
+
180
+ state = self.init_hidden(batch_size)
181
+
182
+ it = fc_feats.data.new(batch_size).long().zero_()
183
+ xt = self.embed(it)
184
+
185
+ output, state = self.core(xt.unsqueeze(0), state)
186
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
187
+
188
+ self.done_beams = self.beam_search(state, logprobs, opt=opt)
189
+
190
+ for k in range(batch_size):
191
+ if sample_n == beam_size:
192
+ for _n in range(sample_n):
193
+ seq_len = self.done_beams[k][_n]['seq'].shape[0]
194
+ seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
195
+ seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
196
+ else:
197
+ seq_len = self.done_beams[k][0]['seq'].shape[0]
198
+ seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
199
+ seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
200
+ # return the samples and their log likelihoods
201
+ return seq, seqLogprobs
202
+
203
+ def _old_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
204
+ sample_method = opt.get('sample_method', 'greedy')
205
+ beam_size = opt.get('beam_size', 1)
206
+ temperature = opt.get('temperature', 1.0)
207
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
208
+ return self._sample_beam(fc_feats, att_feats, opt=opt)
209
+
210
+ batch_size = fc_feats.size(0)
211
+ state = self.init_hidden(batch_size)
212
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
213
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
214
+ for t in range(self.seq_length + 2):
215
+ if t == 0:
216
+ xt = self.img_embed(fc_feats)
217
+ else:
218
+ if t == 1: # input <bos>
219
+ it = fc_feats.data.new(batch_size).long().zero_()
220
+ xt = self.embed(it)
221
+
222
+ output, state = self.core(xt.unsqueeze(0), state)
223
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
224
+
225
+ # sample the next word
226
+ if t == self.seq_length + 1: # skip if we achieve maximum length
227
+ break
228
+ if sample_method == 'greedy':
229
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
230
+ it = it.view(-1).long()
231
+ else:
232
+ if temperature == 1.0:
233
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
234
+ else:
235
+ # scale logprobs by temperature
236
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
237
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
238
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
239
+ it = it.view(-1).long() # and flatten indices for downstream processing
240
+
241
+ if t >= 1:
242
+ # stop when all finished
243
+ if t == 1:
244
+ unfinished = it > 0
245
+ else:
246
+ unfinished = unfinished & (it > 0)
247
+ it = it * unfinished.type_as(it)
248
+ seq[:,t-1] = it # seq[t] is the input at time step t+2
249
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
250
+ if unfinished.sum() == 0:
251
+ break
252
+ return seq, seqLogprobs
253
+
254
+
255
+ # remove bad endings and UNK
256
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
257
+ sample_method = opt.get('sample_method', 'greedy')
258
+ beam_size = opt.get('beam_size', 1)
259
+ temperature = opt.get('temperature', 1.0)
260
+
261
+ sample_n = int(opt.get('sample_n', 1))
262
+ sample_n = 1
263
+ group_size = opt.get('group_size', 1)
264
+ output_logsoftmax = opt.get('output_logsoftmax', 1)
265
+ decoding_constraint = opt.get('decoding_constraint', 0)
266
+ block_trigrams = opt.get('block_trigrams', 0)
267
+ remove_bad_endings = opt.get('remove_bad_endings', 1)
268
+ suppress_UNK = opt.get('suppress_UNK', 1)
269
+
270
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
271
+ return self._sample_beam(fc_feats, att_feats, opt=opt)
272
+
273
+ batch_size = fc_feats.size(0)
274
+ state = self.init_hidden(batch_size)
275
+
276
+ trigrams = [] # will be a list of batch_size dictionaries
277
+
278
+ # seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
279
+ # seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
280
+
281
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
282
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
283
+ for t in range(self.seq_length + 1):
284
+ if t == 0:
285
+ xt = self.img_embed(fc_feats)
286
+ else:
287
+ if t == 1: # input <bos>
288
+ it = fc_feats.data.new(batch_size).long().zero_()
289
+ xt = self.embed(it)
290
+
291
+ output, state = self.core(xt.unsqueeze(0), state)
292
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
293
+
294
+ if decoding_constraint and t > 0:
295
+ tmp = logprobs.new_zeros(logprobs.size())
296
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
297
+ logprobs = logprobs + tmp
298
+
299
+ # print('seq', seq)
300
+ # print('self.seq_length',self.seq_length)
301
+ # print('seq shape', seq.shape)
302
+ if remove_bad_endings and t > 0:
303
+ logprobs[torch.from_numpy(np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
304
+
305
+ # suppress UNK tokens in the decoding
306
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
307
+ logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
308
+
309
+ # if remove_bad_endings and t > 0:
310
+ # tmp = logprobs.new_zeros(logprobs.size())
311
+ # prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
312
+ # # Make it impossible to generate bad_endings
313
+ # tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
314
+ # # tmp[torch.from_numpy(prev_bad.bool()), 0] = float('-inf')
315
+ # logprobs = logprobs + tmp
316
+
317
+ # Mess with trigrams
318
+ # Copy from https://github.com/lukemelas/image-paragraph-captioning
319
+ if block_trigrams and t >= 3:
320
+ # Store trigram generated at last step
321
+ prev_two_batch = seq[:,t-3:t-1]
322
+ for i in range(batch_size): # = seq.size(0)
323
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
324
+ current = seq[i][t-1]
325
+ if t == 3: # initialize
326
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
327
+ elif t > 3:
328
+ if prev_two in trigrams[i]: # add to list
329
+ trigrams[i][prev_two].append(current)
330
+ else: # create list
331
+ trigrams[i][prev_two] = [current]
332
+ # Block used trigrams at next step
333
+ prev_two_batch = seq[:,t-2:t]
334
+ mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
335
+ for i in range(batch_size):
336
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
337
+ if prev_two in trigrams[i]:
338
+ for j in trigrams[i][prev_two]:
339
+ mask[i,j] += 1
340
+ # Apply mask to log probs
341
+ #logprobs = logprobs - (mask * 1e9)
342
+ alpha = 2.0 # = 4
343
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
344
+
345
+ # sample the next word
346
+ if t == self.seq_length+1: # skip if we achieve maximum length
347
+ break
348
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
349
+
350
+ # stop when all finished
351
+ if t == 0:
352
+ unfinished = it != self.eos_idx
353
+ else:
354
+ it[~unfinished] = self.pad_idx # This keeps eos_idx from being overwritten with 0
355
+ logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
356
+ unfinished = unfinished & (it != self.eos_idx)
357
+
358
+ # print('-------logprobs shape:',logprobs.shape)
359
+ # print('-------it shape:',it.shape)
360
+
361
+ seq[:,t-1] = it
362
+ seqLogprobs[:,t-1] = logprobs
363
+ # quit loop if all sequences have finished
364
+ if unfinished.sum() == 0:
365
+ break
366
+ # print('-------seqLogprobs shape:',seqLogprobs.shape)
367
+ # print('-------seq shape:',seq.shape)
368
+ return seq, seqLogprobs
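The `_sample` method above suppresses bad endings and the UNK token by pushing their log-probabilities down before the next word is drawn. A small illustration of that masking with a hypothetical four-word vocabulary (not the repo's vocabulary):

    import numpy as np
    import torch

    vocab = {'1': 'a', '2': 'shoe', '3': 'with', '4': 'UNK'}   # hypothetical; index 0 is <eos>
    bad_endings = ['a', 'with']
    bad_endings_ix = [int(k) for k, v in vocab.items() if v in bad_endings]

    logprobs = torch.randn(2, 5).log_softmax(dim=1)   # batch of 2, vocab_size + 1 = 5
    prev_word = torch.tensor([1, 2])                  # last emitted word of each sequence

    # forbid ending (<eos>, index 0) right after a bad-ending word
    prev_bad = np.isin(prev_word.numpy(), bad_endings_ix)
    logprobs[torch.from_numpy(prev_bad), 0] = float('-inf')
    # push UNK (last index) far down, mirroring suppress_UNK
    logprobs[:, -1] = logprobs[:, -1] - 1000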
captioning/models/TransformerModel.py ADDED
@@ -0,0 +1,367 @@
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondence:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ # torch.manual_seed(42)
27
+ # if torch.cuda.is_available():
28
+ # torch.cuda.manual_seed(42)
29
+
30
+ class EncoderDecoder(nn.Module):
31
+ """
32
+ A standard Encoder-Decoder architecture. Base for this and many
33
+ other models.
34
+ """
35
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
36
+ super(EncoderDecoder, self).__init__()
37
+ self.encoder = encoder
38
+ self.decoder = decoder
39
+ self.src_embed = src_embed
40
+ self.tgt_embed = tgt_embed
41
+ self.generator = generator
42
+
43
+ def forward(self, src, tgt, src_mask, tgt_mask):
44
+ "Take in and process masked src and target sequences."
45
+ return self.decode(self.encode(src, src_mask), src_mask,
46
+ tgt, tgt_mask)
47
+
48
+ def encode(self, src, src_mask):
49
+ return self.encoder(self.src_embed(src), src_mask)
50
+
51
+ def decode(self, memory, src_mask, tgt, tgt_mask):
52
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
53
+
54
+ class Generator(nn.Module):
55
+ "Define standard linear + softmax generation step."
56
+ def __init__(self, d_model, vocab):
57
+ super(Generator, self).__init__()
58
+ self.proj = nn.Linear(d_model, vocab)
59
+
60
+ def forward(self, x):
61
+ return F.log_softmax(self.proj(x), dim=-1)
62
+
63
+ def clones(module, N):
64
+ "Produce N identical layers."
65
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
66
+
67
+ class Encoder(nn.Module):
68
+ "Core encoder is a stack of N layers"
69
+ def __init__(self, layer, N):
70
+ super(Encoder, self).__init__()
71
+ self.layers = clones(layer, N)
72
+ self.norm = LayerNorm(layer.size)
73
+
74
+ def forward(self, x, mask):
75
+ "Pass the input (and mask) through each layer in turn."
76
+ for layer in self.layers:
77
+ x = layer(x, mask)
78
+ return self.norm(x)
79
+
80
+ class LayerNorm(nn.Module):
81
+ "Construct a layernorm module (See citation for details)."
82
+ def __init__(self, features, eps=1e-6):
83
+ super(LayerNorm, self).__init__()
84
+ self.a_2 = nn.Parameter(torch.ones(features))
85
+ self.b_2 = nn.Parameter(torch.zeros(features))
86
+ self.eps = eps
87
+
88
+ def forward(self, x):
89
+ mean = x.mean(-1, keepdim=True)
90
+ std = x.std(-1, keepdim=True)
91
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
92
+
93
+ class SublayerConnection(nn.Module):
94
+ """
95
+ A residual connection followed by a layer norm.
96
+ Note for code simplicity the norm is first as opposed to last.
97
+ """
98
+ def __init__(self, size, dropout):
99
+ super(SublayerConnection, self).__init__()
100
+ self.norm = LayerNorm(size)
101
+ self.dropout = nn.Dropout(dropout)
102
+
103
+ def forward(self, x, sublayer):
104
+ "Apply residual connection to any sublayer with the same size."
105
+ return x + self.dropout(sublayer(self.norm(x)))
106
+
107
+ class EncoderLayer(nn.Module):
108
+ "Encoder is made up of self-attn and feed forward (defined below)"
109
+ def __init__(self, size, self_attn, feed_forward, dropout):
110
+ super(EncoderLayer, self).__init__()
111
+ self.self_attn = self_attn
112
+ self.feed_forward = feed_forward
113
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
114
+ self.size = size
115
+
116
+ def forward(self, x, mask):
117
+ "Follow Figure 1 (left) for connections."
118
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
119
+ return self.sublayer[1](x, self.feed_forward)
120
+
121
+ class Decoder(nn.Module):
122
+ "Generic N layer decoder with masking."
123
+ def __init__(self, layer, N):
124
+ super(Decoder, self).__init__()
125
+ self.layers = clones(layer, N)
126
+ self.norm = LayerNorm(layer.size)
127
+
128
+ def forward(self, x, memory, src_mask, tgt_mask):
129
+ for layer in self.layers:
130
+ x = layer(x, memory, src_mask, tgt_mask)
131
+ return self.norm(x)
132
+
133
+ class DecoderLayer(nn.Module):
134
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
135
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
136
+ super(DecoderLayer, self).__init__()
137
+ self.size = size
138
+ self.self_attn = self_attn
139
+ self.src_attn = src_attn
140
+ self.feed_forward = feed_forward
141
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
142
+
143
+ def forward(self, x, memory, src_mask, tgt_mask):
144
+ "Follow Figure 1 (right) for connections."
145
+ m = memory
146
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
147
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
148
+ return self.sublayer[2](x, self.feed_forward)
149
+
150
+ def subsequent_mask(size):
151
+ "Mask out subsequent positions."
152
+ attn_shape = (1, size, size)
153
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
154
+ return torch.from_numpy(subsequent_mask) == 0
155
+
156
+ def attention(query, key, value, mask=None, dropout=None):
157
+ "Compute 'Scaled Dot Product Attention'"
158
+ d_k = query.size(-1)
159
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
160
+ / math.sqrt(d_k)
161
+ if mask is not None:
162
+ scores = scores.masked_fill(mask == 0, float('-inf'))
163
+ p_attn = F.softmax(scores, dim = -1)
164
+ if dropout is not None:
165
+ p_attn = dropout(p_attn)
166
+ return torch.matmul(p_attn, value), p_attn
167
+
168
+ class MultiHeadedAttention(nn.Module):
169
+ def __init__(self, h, d_model, dropout=0.1):
170
+ "Take in model size and number of heads."
171
+ super(MultiHeadedAttention, self).__init__()
172
+ assert d_model % h == 0
173
+ # We assume d_v always equals d_k
174
+ self.d_k = d_model // h
175
+ self.h = h
176
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
177
+ self.attn = None
178
+ self.dropout = nn.Dropout(p=dropout)
179
+
180
+ def forward(self, query, key, value, mask=None):
181
+ "Implements Figure 2"
182
+ if mask is not None:
183
+ # Same mask applied to all h heads.
184
+ mask = mask.unsqueeze(1)
185
+ nbatches = query.size(0)
186
+
187
+ # 1) Do all the linear projections in batch from d_model => h x d_k
188
+ query, key, value = \
189
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
190
+ for l, x in zip(self.linears, (query, key, value))]
191
+
192
+ # 2) Apply attention on all the projected vectors in batch.
193
+ x, self.attn = attention(query, key, value, mask=mask,
194
+ dropout=self.dropout)
195
+
196
+ # 3) "Concat" using a view and apply a final linear.
197
+ x = x.transpose(1, 2).contiguous() \
198
+ .view(nbatches, -1, self.h * self.d_k)
199
+ return self.linears[-1](x)
200
+
201
+ class PositionwiseFeedForward(nn.Module):
202
+ "Implements FFN equation."
203
+ def __init__(self, d_model, d_ff, dropout=0.1):
204
+ super(PositionwiseFeedForward, self).__init__()
205
+ self.w_1 = nn.Linear(d_model, d_ff)
206
+ self.w_2 = nn.Linear(d_ff, d_model)
207
+ self.dropout = nn.Dropout(dropout)
208
+
209
+ def forward(self, x):
210
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
211
+
212
+ class Embeddings(nn.Module):
213
+ def __init__(self, d_model, vocab):
214
+ super(Embeddings, self).__init__()
215
+ self.lut = nn.Embedding(vocab, d_model)
216
+ self.d_model = d_model
217
+
218
+ def forward(self, x):
219
+ return self.lut(x) * math.sqrt(self.d_model)
220
+
221
+ class PositionalEncoding(nn.Module):
222
+ "Implement the PE function."
223
+ def __init__(self, d_model, dropout, max_len=5000):
224
+ super(PositionalEncoding, self).__init__()
225
+ self.dropout = nn.Dropout(p=dropout)
226
+
227
+ # Compute the positional encodings once in log space.
228
+ pe = torch.zeros(max_len, d_model)
229
+ position = torch.arange(0, max_len).unsqueeze(1).float()
230
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
231
+ -(math.log(10000.0) / d_model))
232
+ pe[:, 0::2] = torch.sin(position * div_term)
233
+ pe[:, 1::2] = torch.cos(position * div_term)
234
+ pe = pe.unsqueeze(0)
235
+ self.register_buffer('pe', pe)
236
+
237
+ def forward(self, x):
238
+ x = x + self.pe[:, :x.size(1)]
239
+ return self.dropout(x)
240
+
241
+ class TransformerModel(AttModel):
242
+
243
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
244
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
245
+ "Helper: Construct a model from hyperparameters."
246
+ c = copy.deepcopy
247
+ attn = MultiHeadedAttention(h, d_model, dropout)
248
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
249
+ position = PositionalEncoding(d_model, dropout)
250
+ model = EncoderDecoder(
251
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
252
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
253
+ c(ff), dropout), N_dec),
254
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
255
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
256
+ Generator(d_model, tgt_vocab))
257
+
258
+ # This was important from their code.
259
+ # Initialize parameters with Glorot / fan_avg.
260
+ for p in model.parameters():
261
+ if p.dim() > 1:
262
+ nn.init.xavier_uniform_(p)
263
+ return model
264
+
265
+ def __init__(self, opt):
266
+ super(TransformerModel, self).__init__(opt)
267
+ self.opt = opt
268
+ # self.config = yaml.load(open(opt.config_file))
269
+
270
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
271
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
272
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
273
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
274
+ self.h = getattr(opt, 'num_att_heads', 8)
275
+ self.dropout = getattr(opt, 'dropout', 0.1)
276
+
277
+ delattr(self, 'att_embed')
278
+ self.att_embed = nn.Sequential(*(
279
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
280
+ (nn.Linear(self.att_feat_size, self.d_model),
281
+ nn.ReLU(),
282
+ nn.Dropout(self.drop_prob_lm))+
283
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
284
+
285
+ delattr(self, 'embed')
286
+ self.embed = lambda x : x
287
+ delattr(self, 'fc_embed')
288
+ self.fc_embed = lambda x : x
289
+ delattr(self, 'logit')
290
+ del self.ctx2att
291
+
292
+ tgt_vocab = self.vocab_size + 1
293
+
294
+
295
+ self.model = self.make_model(0, tgt_vocab,
296
+ N_enc=self.N_enc,
297
+ N_dec=self.N_dec,
298
+ d_model=self.d_model,
299
+ d_ff=self.d_ff,
300
+ h=self.h,
301
+ dropout=self.dropout)
302
+
303
+ def logit(self, x): # unsafe way
304
+ return self.model.generator.proj(x)
305
+
306
+ def init_hidden(self, bsz):
307
+ return []
308
+
309
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
310
+
311
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
312
+ memory = self.model.encode(att_feats, att_masks)
313
+
314
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
315
+
316
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
317
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
318
+
319
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
320
+
321
+ if att_masks is None:
322
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
323
+ att_masks = att_masks.unsqueeze(-2)
324
+
325
+ if seq is not None:
326
+ # crop the last one
327
+ # seq = seq[:,:-1]
328
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
329
+ seq_mask[:,0] = 1 # bos
330
+
331
+ seq_mask = seq_mask.unsqueeze(-2)
332
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
333
+
334
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
335
+ if seq_per_img > 1:
336
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
337
+ [att_feats, att_masks]
338
+ )
339
+ else:
340
+ seq_mask = None
341
+
342
+ return att_feats, seq, att_masks, seq_mask
343
+
344
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
345
+ if seq.ndim == 3: # B * seq_per_img * seq_len
346
+ seq = seq.reshape(-1, seq.shape[2])
347
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
348
+
349
+ out = self.model(att_feats, seq, att_masks, seq_mask)
350
+
351
+ outputs = self.model.generator(out)
352
+ return outputs
353
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
354
+
355
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
356
+ """
357
+ state = [ys.unsqueeze(0)]
358
+ """
359
+ if len(state) == 0:
360
+ ys = it.unsqueeze(1)
361
+ else:
362
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
363
+ out = self.model.decode(memory, mask,
364
+ ys,
365
+ subsequent_mask(ys.size(1))
366
+ .to(memory.device))
367
+ return out[:, -1], [ys.unsqueeze(0)]
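`subsequent_mask` above builds the causal mask that keeps each decoding position from attending to later positions. A quick check of its output in isolation (duplicating the helper so the snippet stands alone):

    import numpy as np
    import torch

    def subsequent_mask(size):
        attn_shape = (1, size, size)
        mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
        return torch.from_numpy(mask) == 0

    print(subsequent_mask(4)[0].int())
    # lower-triangular: position t may attend to positions 0..t only
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]], dtype=torch.int32)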
captioning/models/__init__.py ADDED
@@ -0,0 +1,76 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from .ShowTellModel import ShowTellModel
12
+ from .FCModel import FCModel
13
+ from .AttModel import *
14
+ from .TransformerModel import TransformerModel
15
+ from .cachedTransformer import TransformerModel as cachedTransformer
16
+ from .BertCapModel import BertCapModel
17
+ from .M2Transformer import M2TransformerModel
18
+ from .AoAModel import AoAModel
19
+ from .OldModel import ShowAttendTellModel
20
+
21
+ def setup(opt):
22
+ if opt.caption_model in ['fc', 'show_tell']:
23
+ print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
24
+ if opt.caption_model == 'fc':
25
+ print('Use newfc instead of fc')
26
+ if opt.caption_model == 'fc':
27
+ model = FCModel(opt)
28
+ elif opt.caption_model == 'language_model':
29
+ model = LMModel(opt)
30
+ elif opt.caption_model == 'newfc':
31
+ model = NewFCModel(opt)
32
+ elif opt.caption_model == 'show_tell':
33
+ model = ShowTellModel(opt)
34
+ elif opt.caption_model == 'show_attend_tell':
35
+ model = ShowAttendTellModel(opt)
36
+ # Att2in model in self-critical
37
+ elif opt.caption_model == 'att2in':
38
+ model = Att2inModel(opt)
39
+ # Att2in model with two-layer MLP img embedding and word embedding
40
+ elif opt.caption_model == 'att2in2':
41
+ model = Att2in2Model(opt)
42
+ elif opt.caption_model == 'att2all2':
43
+ print('Warning: this is not a correct implementation of the att2all model in the original paper.')
44
+ model = Att2all2Model(opt)
45
+ # Adaptive Attention model from Knowing when to look
46
+ elif opt.caption_model == 'adaatt':
47
+ model = AdaAttModel(opt)
48
+ # Adaptive Attention with maxout lstm
49
+ elif opt.caption_model == 'adaattmo':
50
+ model = AdaAttMOModel(opt)
51
+ # Top-down attention model
52
+ elif opt.caption_model in ['topdown', 'updown']:
53
+ model = UpDownModel(opt)
54
+ # StackAtt
55
+ elif opt.caption_model == 'stackatt':
56
+ model = StackAttModel(opt)
57
+ # DenseAtt
58
+ elif opt.caption_model == 'denseatt':
59
+ model = DenseAttModel(opt)
60
+ # Transformer
61
+ elif opt.caption_model == 'transformer':
62
+ if getattr(opt, 'cached_transformer', False):
63
+ model = cachedTransformer(opt)
64
+ else:
65
+ model = TransformerModel(opt)
66
+ # AoANet
67
+ elif opt.caption_model == 'aoa':
68
+ model = AoAModel(opt)
69
+ elif opt.caption_model == 'bert':
70
+ model = BertCapModel(opt)
71
+ elif opt.caption_model == 'm2transformer':
72
+ model = M2TransformerModel(opt)
73
+ else:
74
+ raise Exception("Caption model not supported: {}".format(opt.caption_model))
75
+
76
+ return model
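`setup` above is a plain string dispatch on `opt.caption_model`. A reduced sketch of the same pattern with a stand-in class (the real constructors need a full option object from captioning.utils.opts plus a vocabulary):

    from argparse import Namespace

    class FakeTransformer:                      # stand-in; the real branch builds TransformerModel(opt)
        def __init__(self, opt):
            self.opt = opt

    def tiny_setup(opt):
        registry = {'transformer': FakeTransformer}
        if opt.caption_model not in registry:
            raise Exception("Caption model not supported: {}".format(opt.caption_model))
        return registry[opt.caption_model](opt)

    model = tiny_setup(Namespace(caption_model='transformer'))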
captioning/models/cachedTransformer.py ADDED
@@ -0,0 +1,420 @@
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondence:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.src_embed = src_embed
36
+ self.tgt_embed = tgt_embed
37
+ self.generator = generator
38
+
39
+ def forward(self, src, tgt, src_mask, tgt_mask):
40
+ "Take in and process masked src and target sequences."
41
+ return self.decode(self.encode(src, src_mask), src_mask,
42
+ tgt, tgt_mask)
43
+
44
+ def encode(self, src, src_mask):
45
+ return self.encoder(self.src_embed(src), src_mask)
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
48
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
49
+
50
+ class Generator(nn.Module):
51
+ "Define standard linear + softmax generation step."
52
+ def __init__(self, d_model, vocab):
53
+ super(Generator, self).__init__()
54
+ self.proj = nn.Linear(d_model, vocab)
55
+
56
+ def forward(self, x):
57
+ return F.log_softmax(self.proj(x), dim=-1)
58
+
59
+ def clones(module, N):
60
+ "Produce N identical layers."
61
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62
+
63
+ class Encoder(nn.Module):
64
+ "Core encoder is a stack of N layers"
65
+ def __init__(self, layer, N):
66
+ super(Encoder, self).__init__()
67
+ self.layers = clones(layer, N)
68
+ self.norm = LayerNorm(layer.size)
69
+
70
+ def forward(self, x, mask):
71
+ "Pass the input (and mask) through each layer in turn."
72
+ for layer in self.layers:
73
+ x = layer(x, mask)
74
+ return self.norm(x)
75
+
76
+ class LayerNorm(nn.Module):
77
+ "Construct a layernorm module (See citation for details)."
78
+ def __init__(self, features, eps=1e-6):
79
+ super(LayerNorm, self).__init__()
80
+ self.a_2 = nn.Parameter(torch.ones(features))
81
+ self.b_2 = nn.Parameter(torch.zeros(features))
82
+ self.eps = eps
83
+
84
+ def forward(self, x):
85
+ mean = x.mean(-1, keepdim=True)
86
+ std = x.std(-1, keepdim=True)
87
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88
+
89
+ class SublayerConnection(nn.Module):
90
+ """
91
+ A residual connection followed by a layer norm.
92
+ Note for code simplicity the norm is first as opposed to last.
93
+ """
94
+ def __init__(self, size, dropout):
95
+ super(SublayerConnection, self).__init__()
96
+ self.norm = LayerNorm(size)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ def forward(self, x, sublayer):
100
+ "Apply residual connection to any sublayer with the same size."
101
+ _x = sublayer(self.norm(x))
102
+ if type(_x) is tuple: # for multi-head attention that returns past
103
+ return x + self.dropout(_x[0]), _x[1]
104
+ return x + self.dropout(_x)
105
+
106
+ class EncoderLayer(nn.Module):
107
+ "Encoder is made up of self-attn and feed forward (defined below)"
108
+ def __init__(self, size, self_attn, feed_forward, dropout):
109
+ super(EncoderLayer, self).__init__()
110
+ self.self_attn = self_attn
111
+ self.feed_forward = feed_forward
112
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
113
+ self.size = size
114
+
115
+ def forward(self, x, mask):
116
+ "Follow Figure 1 (left) for connections."
117
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
118
+ return self.sublayer[1](x, self.feed_forward)
119
+
120
+ class Decoder(nn.Module):
121
+ "Generic N layer decoder with masking."
122
+ def __init__(self, layer, N):
123
+ super(Decoder, self).__init__()
124
+ self.layers = clones(layer, N)
125
+ self.norm = LayerNorm(layer.size)
126
+
127
+ def forward(self, x, memory, src_mask, tgt_mask, past=None):
128
+ if past is not None:
129
+ present = [[], []]
130
+ x = x[:, -1:]
131
+ tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
132
+ past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
133
+ else:
134
+ past = [None] * len(self.layers)
135
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
136
+ x = layer(x, memory, src_mask, tgt_mask,
137
+ layer_past)
138
+ if layer_past is not None:
139
+ present[0].append(x[1][0])
140
+ present[1].append(x[1][1])
141
+ x = x[0]
142
+ if past[0] is None:
143
+ return self.norm(x)
144
+ else:
145
+ return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
146
+
147
+
148
+ class DecoderLayer(nn.Module):
149
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
150
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
151
+ super(DecoderLayer, self).__init__()
152
+ self.size = size
153
+ self.self_attn = self_attn
154
+ self.src_attn = src_attn
155
+ self.feed_forward = feed_forward
156
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
157
+
158
+ def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
159
+ "Follow Figure 1 (right) for connections."
160
+ m = memory
161
+ if layer_past is None:
162
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
163
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
164
+ return self.sublayer[2](x, self.feed_forward)
165
+ else:
166
+ present = [None, None]
167
+ x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
168
+ x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
169
+ return self.sublayer[2](x, self.feed_forward), present
170
+
171
+ def subsequent_mask(size):
172
+ "Mask out subsequent positions."
173
+ attn_shape = (1, size, size)
174
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
175
+ return torch.from_numpy(subsequent_mask) == 0
176
+
177
+ def attention(query, key, value, mask=None, dropout=None):
178
+ "Compute 'Scaled Dot Product Attention'"
179
+ d_k = query.size(-1)
180
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
181
+ / math.sqrt(d_k)
182
+ if mask is not None:
183
+ scores = scores.masked_fill(mask == 0, float('-inf'))
184
+ p_attn = F.softmax(scores, dim = -1)
185
+ if dropout is not None:
186
+ p_attn = dropout(p_attn)
187
+ return torch.matmul(p_attn, value), p_attn
188
+
189
+ class MultiHeadedAttention(nn.Module):
190
+ def __init__(self, h, d_model, dropout=0.1):
191
+ "Take in model size and number of heads."
192
+ super(MultiHeadedAttention, self).__init__()
193
+ assert d_model % h == 0
194
+ # We assume d_v always equals d_k
195
+ self.d_k = d_model // h
196
+ self.h = h
197
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
198
+ self.attn = None
199
+ self.dropout = nn.Dropout(p=dropout)
200
+
201
+ def forward(self, query, key, value, mask=None, layer_past=None):
202
+ "Implements Figure 2"
203
+ if mask is not None:
204
+ # Same mask applied to all h heads.
205
+ mask = mask.unsqueeze(1)
206
+ nbatches = query.size(0)
207
+
208
+ # The past works differently here. For self-attention, the key and value are updated incrementally.
209
+ # For src_attn the past is fixed.
210
+
211
+ # For src_attn, when the layer past is ready
212
+ if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1
213
+ query = self.linears[0](query)
214
+ key, value = layer_past[0], layer_past[1]
215
+ present = torch.stack([key, value])
216
+ else:
217
+ # 1) Do all the linear projections in batch from d_model => h x d_k
218
+ query, key, value = \
219
+ [l(x) for l, x in zip(self.linears, (query, key, value))]
220
+
221
+ # self attn + past OR the first time step of src attn
222
+ if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
223
+ past_key, past_value = layer_past[0], layer_past[1]
224
+ key = torch.cat((past_key, key), dim=1)
225
+ value = torch.cat((past_value, value), dim=1)
226
+ present = torch.stack([key, value])
227
+
228
+ query, key, value = \
229
+ [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
230
+ for x in [query, key, value]]
231
+
232
+ # 2) Apply attention on all the projected vectors in batch.
233
+ x, self.attn = attention(query, key, value, mask=mask,
234
+ dropout=self.dropout)
235
+
236
+ # 3) "Concat" using a view and apply a final linear.
237
+ x = x.transpose(1, 2).contiguous() \
238
+ .view(nbatches, -1, self.h * self.d_k)
239
+ if layer_past is not None:
240
+ return self.linears[-1](x), present
241
+ else:
242
+ return self.linears[-1](x)
243
+
244
+ class PositionwiseFeedForward(nn.Module):
245
+ "Implements FFN equation."
246
+ def __init__(self, d_model, d_ff, dropout=0.1):
247
+ super(PositionwiseFeedForward, self).__init__()
248
+ self.w_1 = nn.Linear(d_model, d_ff)
249
+ self.w_2 = nn.Linear(d_ff, d_model)
250
+ self.dropout = nn.Dropout(dropout)
251
+
252
+ def forward(self, x):
253
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
254
+
255
+ class Embeddings(nn.Module):
256
+ def __init__(self, d_model, vocab):
257
+ super(Embeddings, self).__init__()
258
+ self.lut = nn.Embedding(vocab, d_model)
259
+ self.d_model = d_model
260
+
261
+ def forward(self, x):
262
+ return self.lut(x) * math.sqrt(self.d_model)
263
+
264
+ class PositionalEncoding(nn.Module):
265
+ "Implement the PE function."
266
+ def __init__(self, d_model, dropout, max_len=5000):
267
+ super(PositionalEncoding, self).__init__()
268
+ self.dropout = nn.Dropout(p=dropout)
269
+
270
+ # Compute the positional encodings once in log space.
271
+ pe = torch.zeros(max_len, d_model)
272
+ position = torch.arange(0, max_len).unsqueeze(1).float()
273
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
274
+ -(math.log(10000.0) / d_model))
275
+ pe[:, 0::2] = torch.sin(position * div_term)
276
+ pe[:, 1::2] = torch.cos(position * div_term)
277
+ pe = pe.unsqueeze(0)
278
+ self.register_buffer('pe', pe)
279
+
280
+ def forward(self, x):
281
+ x = x + self.pe[:, :x.size(1)]
282
+ return self.dropout(x)
283
+
284
+ class TransformerModel(AttModel):
285
+
286
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
287
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
288
+ "Helper: Construct a model from hyperparameters."
289
+ c = copy.deepcopy
290
+ attn = MultiHeadedAttention(h, d_model, dropout)
291
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
292
+ position = PositionalEncoding(d_model, dropout)
293
+ model = EncoderDecoder(
294
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
295
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
296
+ c(ff), dropout), N_dec),
297
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
298
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
299
+ Generator(d_model, tgt_vocab))
300
+
301
+ # This was important from their code.
302
+ # Initialize parameters with Glorot / fan_avg.
303
+ for p in model.parameters():
304
+ if p.dim() > 1:
305
+ nn.init.xavier_uniform_(p)
306
+ return model
307
+
308
+ def __init__(self, opt):
309
+ super(TransformerModel, self).__init__(opt)
310
+ self.opt = opt
311
+ # self.config = yaml.load(open(opt.config_file))
312
+
313
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
314
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
315
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
316
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
317
+ self.h = getattr(opt, 'num_att_heads', 8)
318
+ self.dropout = getattr(opt, 'dropout', 0.1)
319
+
320
+ delattr(self, 'att_embed')
321
+ self.att_embed = nn.Sequential(*(
322
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
323
+ (nn.Linear(self.att_feat_size, self.d_model),
324
+ nn.ReLU(),
325
+ nn.Dropout(self.drop_prob_lm))+
326
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
327
+
328
+ delattr(self, 'embed')
329
+ self.embed = lambda x : x
330
+ delattr(self, 'fc_embed')
331
+ self.fc_embed = lambda x : x
332
+ delattr(self, 'logit')
333
+ del self.ctx2att
334
+
335
+ tgt_vocab = self.vocab_size + 1
336
+
337
+
338
+ self.model = self.make_model(0, tgt_vocab,
339
+ N_enc=self.N_enc,
340
+ N_dec=self.N_dec,
341
+ d_model=self.d_model,
342
+ d_ff=self.d_ff,
343
+ h=self.h,
344
+ dropout=self.dropout)
345
+
346
+ def logit(self, x): # unsafe way
347
+ return self.model.generator.proj(x)
348
+
349
+ def init_hidden(self, bsz):
350
+ return []
351
+
352
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
353
+
354
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
355
+ memory = self.model.encode(att_feats, att_masks)
356
+
357
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
358
+
359
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
360
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
361
+
362
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
363
+
364
+ if att_masks is None:
365
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
366
+ att_masks = att_masks.unsqueeze(-2)
367
+
368
+ if seq is not None:
369
+ # crop the last one
370
+ # seq = seq[:,:-1]
371
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
372
+ seq_mask[:,0] = 1 # bos
373
+
374
+ seq_mask = seq_mask.unsqueeze(-2)
375
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
376
+
377
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
378
+ if seq_per_img > 1:
379
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
380
+ [att_feats, att_masks]
381
+ )
382
+ else:
383
+ seq_mask = None
384
+
385
+ return att_feats, seq, att_masks, seq_mask
386
+
387
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
388
+ if seq.ndim == 3: # B * seq_per_img * seq_len
389
+ seq = seq.reshape(-1, seq.shape[2])
390
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
391
+
392
+ out = self.model(att_feats, seq, att_masks, seq_mask)
393
+
394
+ outputs = self.model.generator(out)
395
+ return outputs
396
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
397
+
398
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
399
+ """
400
+ state is the precomputed key/value. N_dec x seq_len x d_model
401
+ Note: due to the layer norm, this is not equivalent to the stateless version,
402
+ but it behaves similarly
403
+ """
404
+ # state is tokens + past
405
+ if len(state) == 0:
406
+ ys = it.unsqueeze(1)
407
+ # basically empty state, just to let it know to return past
408
+ # The second dim has to be batch_size, for beam search purpose
409
+ past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
410
+ fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
411
+ # 2 for self attn, 2 for src attn
412
+ else:
413
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
414
+ past = state[1:]
415
+ out, past = self.model.decode(memory, mask,
416
+ ys, # We still feed the full past words, because we need it for position embedding to know the position id
417
+ subsequent_mask(ys.size(1))
418
+ .to(memory.device),
419
+ past=past)
420
+ return out[:, -1], [ys.unsqueeze(0)] + past
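The cached decoder above avoids recomputing attention keys and values for already-decoded tokens: each step appends its new key/value to the running `past`. A toy sketch of just that bookkeeping, outside the model classes:

    import torch

    batch, d_model = 2, 8
    past_key = torch.zeros(batch, 0, d_model)       # empty cache before the first step
    past_value = torch.zeros(batch, 0, d_model)

    for t in range(3):
        new_key = torch.randn(batch, 1, d_model)    # projection of the current token
        new_value = torch.randn(batch, 1, d_model)
        past_key = torch.cat((past_key, new_key), dim=1)
        past_value = torch.cat((past_value, new_value), dim=1)

    print(past_key.shape)   # torch.Size([2, 3, 8]) -- grows by one slot per decoded token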
captioning/models/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+ def repeat_tensors(n, x):
4
+ """
5
+ For a tensor of size Bx..., we repeat it n times, and make it Bnx...
6
+ For collections, do nested repeat
7
+ """
8
+ if torch.is_tensor(x):
9
+ x = x.unsqueeze(1) # Bx1x...
10
+ x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
11
+ x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
12
+ elif type(x) is list or type(x) is tuple:
13
+ x = [repeat_tensors(n, _) for _ in x]
14
+ return x
15
+
16
+
17
+ def split_tensors(n, x):
18
+ if torch.is_tensor(x):
19
+ assert x.shape[0] % n == 0
20
+ x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
21
+ elif type(x) is list or type(x) is tuple:
22
+ x = [split_tensors(n, _) for _ in x]
23
+ elif x is None:
24
+ x = [None] * n
25
+ return x
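Assuming the two helpers above are in scope, a short usage example: `repeat_tensors` interleaves n copies of each row and `split_tensors` undoes it:

    import torch

    x = torch.arange(6).view(3, 2)            # B = 3
    y = repeat_tensors(2, x)                   # Bn = 6, each row repeated twice
    print(y.shape)                             # torch.Size([6, 2])
    print(split_tensors(2, y)[0].equal(x))     # True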
captioning/modules/.DS_Store ADDED
Binary file (6.15 kB). View file
 
captioning/modules/loss_wrapper.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ from . import losses
3
+ from ..utils.rewards import init_scorer, get_self_critical_reward
4
+
5
+ class LossWrapper(torch.nn.Module):
6
+ def __init__(self, model, opt):
7
+ super(LossWrapper, self).__init__()
8
+ self.opt = opt
9
+ self.model = model
10
+ if opt.label_smoothing > 0:
11
+ self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
12
+ else:
13
+ self.crit = losses.LanguageModelCriterion()
14
+ self.rl_crit = losses.RewardCriterion()
15
+ self.struc_crit = losses.StructureLosses(opt)
16
+
17
+ def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
18
+ sc_flag, struc_flag):
19
+ opt = self.opt
20
+
21
+ out = {}
22
+ if struc_flag:
23
+ if opt.structure_loss_weight < 1:
24
+ lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
25
+ else:
26
+ lm_loss = torch.tensor(0).type_as(fc_feats)
27
+ if opt.structure_loss_weight > 0:
28
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
29
+ opt={'sample_method':opt.train_sample_method,
30
+ 'beam_size':opt.train_beam_size,
31
+ 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
32
+ or not 'margin' in opt.structure_loss_type,
33
+ 'sample_n': opt.train_sample_n},
34
+ mode='sample')
35
+ gts = [gts[_] for _ in gt_indices.tolist()]
36
+ struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
37
+ else:
38
+ struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
39
+ 'reward': torch.tensor(0).type_as(fc_feats)}
40
+ loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
41
+ out['lm_loss'] = lm_loss
42
+ out['struc_loss'] = struc_loss['loss']
43
+ out['reward'] = struc_loss['reward']
44
+ elif not sc_flag:
45
+ loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
46
+ else:
47
+ self.model.eval()
48
+ with torch.no_grad():
49
+ greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
50
+ mode='sample',
51
+ opt={'sample_method': opt.sc_sample_method,
52
+ 'beam_size': opt.sc_beam_size})
53
+ self.model.train()
54
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
55
+ opt={'sample_method':opt.train_sample_method,
56
+ 'beam_size':opt.train_beam_size,
57
+ 'sample_n': opt.train_sample_n},
58
+ mode='sample')
59
+ gts = [gts[_] for _ in gt_indices.tolist()]
60
+ reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
61
+ reward = torch.from_numpy(reward).to(sample_logprobs)
62
+ loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
63
+ out['reward'] = reward[:,0].mean()
64
+ out['loss'] = loss
65
+ return out
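In the `sc_flag` branch above, the reward for each sampled caption is its score minus the greedy-decoded baseline for the same image. A toy numeric sketch of that baseline subtraction (made-up scores, not CIDEr values):

    import numpy as np

    greedy_scores = np.array([0.80, 0.60])            # one greedy caption per image
    sample_scores = np.array([[0.90, 0.70],           # train_sample_n = 2 samples per image
                              [0.55, 0.65]])
    reward = sample_scores - greedy_scores[:, None]   # positive -> reinforce, negative -> suppress
    # reward per sample: [[0.10, -0.10], [-0.05, 0.05]]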
captioning/modules/losses.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ import torch.nn as nn
+ import torch.nn.functional as F
3
+ from ..utils.rewards import get_scores, get_self_cider_scores
4
+
5
+ class RewardCriterion(nn.Module):
6
+ def __init__(self):
7
+ super(RewardCriterion, self).__init__()
8
+
9
+ def forward(self, input, seq, reward):
10
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
11
+
12
+ input = input.reshape(-1)
13
+ reward = reward.reshape(-1)
14
+ mask = (seq>0).to(input)
15
+ mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
16
+ output = - input * reward * mask
17
+ output = torch.sum(output) / torch.sum(mask)
18
+
19
+ return output
20
+
21
+ class StructureLosses(nn.Module):
22
+ """
23
+ This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
24
+ """
25
+ def __init__(self, opt):
26
+ super(StructureLosses, self).__init__()
27
+ self.opt = opt
28
+ self.loss_type = opt.structure_loss_type
29
+
30
+ def forward(self, input, seq, data_gts):
31
+ """
32
+ Input is either logits or log softmax
33
+ """
34
+ out = {}
35
+
36
+ batch_size = input.size(0)# batch_size = sample_size * seq_per_img
37
+ seq_per_img = batch_size // len(data_gts)
38
+
39
+ assert seq_per_img == self.opt.train_sample_n, seq_per_img
40
+
41
+ mask = (seq>0).to(input)
42
+ mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
43
+
44
+ scores = get_scores(data_gts, seq, self.opt)
45
+ scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
46
+ out['reward'] = scores #.mean()
47
+ if self.opt.entropy_reward_weight > 0:
48
+ entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
49
+ entropy = (entropy * mask).sum(1) / mask.sum(1)
50
+ print('entropy', entropy.mean().item())
51
+ scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
52
+ # rescale cost to [0,1]
53
+ costs = - scores
54
+ if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
55
+ costs = costs - costs.min(1, keepdim=True)[0]
56
+ costs = costs / costs.max(1, keepdim=True)[0]
57
+ # in principle
58
+ # Only risk need such rescale
59
+ # margin should be alright; Let's try.
60
+
61
+ # Gather input: BxTxD -> BxT
62
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
63
+
64
+ if self.loss_type == 'seqnll':
65
+ # input is logsoftmax
66
+ input = input * mask
67
+ input = input.sum(1) / mask.sum(1)
68
+ input = input.view(-1, seq_per_img)
69
+
70
+ target = costs.min(1)[1]
71
+ output = F.cross_entropy(input, target)
72
+ elif self.loss_type == 'risk':
73
+ # input is logsoftmax
74
+ input = input * mask
75
+ input = input.sum(1)
76
+ input = input.view(-1, seq_per_img)
77
+
78
+ output = (F.softmax(input.exp()) * costs).sum(1).mean()
79
+
80
+ # test
81
+ # avg_scores = input
82
+ # probs = F.softmax(avg_scores.exp_())
83
+ # loss = (probs * costs.type_as(probs)).sum() / input.size(0)
84
+ # print(output.item(), loss.item())
85
+
86
+ elif self.loss_type == 'max_margin':
87
+ # input is logits
88
+ input = input * mask
89
+ input = input.sum(1) / mask.sum(1)
90
+ input = input.view(-1, seq_per_img)
91
+ _, __ = costs.min(1, keepdim=True)
92
+ costs_star = _
93
+ input_star = input.gather(1, __)
94
+ output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
95
+ output = output.mean()
96
+
97
+ # sanity test
98
+ # avg_scores = input + costs
99
+ # scores_with_high_target = avg_scores.clone()
100
+ # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
101
+
102
+ # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
103
+ # avg_scores = avg_scores.gather(1, target_and_offender_index)
104
+ # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
105
+ # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
106
+ # print(loss.item() * 2, output.item())
107
+
108
+ elif self.loss_type == 'multi_margin':
109
+ # input is logits
110
+ input = input * mask
111
+ input = input.sum(1) / mask.sum(1)
112
+ input = input.view(-1, seq_per_img)
113
+ costs_star, min_idx = costs.min(1, keepdim=True)
114
+ # cost and model score of the minimum-cost sample
115
+ input_star = input.gather(1, min_idx)
116
+ output = F.relu(costs - costs_star - input_star + input)
117
+ output = output.mean()
118
+
119
+ # sanity test
120
+ # avg_scores = input + costs
121
+ # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
122
+ # print(output, loss)
123
+
124
+ elif self.loss_type == 'softmax_margin':
125
+ # input is logsoftmax
126
+ input = input * mask
127
+ input = input.sum(1) / mask.sum(1)
128
+ input = input.view(-1, seq_per_img)
129
+
130
+ input = input + costs
131
+ target = costs.min(1)[1]
132
+ output = F.cross_entropy(input, target)
133
+
134
+ elif self.loss_type == 'real_softmax_margin':
135
+ # input is logits
136
+ # This is what was originally defined in Kevin's paper
137
+ # The result should be equivalent to softmax_margin
138
+ input = input * mask
139
+ input = input.sum(1) / mask.sum(1)
140
+ input = input.view(-1, seq_per_img)
141
+
142
+ input = input + costs
143
+ target = costs.min(1)[1]
144
+ output = F.cross_entropy(input, target)
145
+
146
+ elif self.loss_type == 'new_self_critical':
147
+ """
148
+ A variant of self-critical training.
149
+ Standard self-critical uses the greedy decoding score as the baseline;
150
+ this setting uses the average score of the remaining samples as the baseline
151
+ (given n samples c1...cn, reward1 = score1 - 1/(n-1) * (score2 + ... + scoren)).
152
+ """
153
+ baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
154
+ scores = scores - baseline
155
+ # self cider used as reward to promote diversity (not working that much in this way)
156
+ if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
157
+ _scores = get_self_cider_scores(data_gts, seq, self.opt)
158
+ _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
159
+ _scores = _scores.expand_as(scores - 1)
160
+ scores += self.opt.self_cider_reward_weight * _scores
161
+ output = - input * mask * scores.view(-1, 1)
162
+ output = torch.sum(output) / torch.sum(mask)
163
+
164
+ out['loss'] = output
165
+ return out
166
+
167
+ class LanguageModelCriterion(nn.Module):
168
+ def __init__(self):
169
+ super(LanguageModelCriterion, self).__init__()
170
+
171
+ def forward(self, input, target, mask):
172
+ if target.ndim == 3:
173
+ target = target.reshape(-1, target.shape[2])
174
+ mask = mask.reshape(-1, mask.shape[2])
175
+ # truncate to the same size
176
+ target = target[:, :input.size(1)]
177
+ mask = mask[:, :input.size(1)].to(input)
178
+
179
+ output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
180
+ # Average over each token
181
+ output = torch.sum(output) / torch.sum(mask)
182
+
183
+ return output
184
+
185
+ class LabelSmoothing(nn.Module):
186
+ "Implement label smoothing."
187
+ def __init__(self, size=0, padding_idx=0, smoothing=0.0):
188
+ super(LabelSmoothing, self).__init__()
189
+ self.criterion = nn.KLDivLoss(reduction='none')  # per-element KL; reduced manually in forward
190
+ # self.padding_idx = padding_idx
191
+ self.confidence = 1.0 - smoothing
192
+ self.smoothing = smoothing
193
+ # self.size = size
194
+ self.true_dist = None
195
+
196
+ def forward(self, input, target, mask):
197
+ if target.ndim == 3:
198
+ target = target.reshape(-1, target.shape[2])
199
+ mask = mask.reshape(-1, mask.shape[2])
200
+ # truncate to the same size
201
+ target = target[:, :input.size(1)]
202
+ mask = mask[:, :input.size(1)]
203
+
204
+ input = input.reshape(-1, input.size(-1))
205
+ target = target.reshape(-1)
206
+ mask = mask.reshape(-1).to(input)
207
+
208
+ # assert x.size(1) == self.size
209
+ self.size = input.size(1)
210
+ # true_dist = x.data.clone()
211
+ true_dist = input.data.clone()
212
+ # true_dist.fill_(self.smoothing / (self.size - 2))
213
+ true_dist.fill_(self.smoothing / (self.size - 1))
214
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
215
+ # true_dist[:, self.padding_idx] = 0
216
+ # mask = torch.nonzero(target.data == self.padding_idx)
217
+ # self.true_dist = true_dist
218
+ return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
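A minimal sketch (illustrative values, not part of this commit) of the leave-one-out baseline used by the 'new_self_critical' branch above: each sample's reward is its own score minus the mean score of the other samples drawn for the same image.

import torch

scores = torch.tensor([[0.9, 0.5, 0.1]])  # 3 sampled captions for one image
baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
print(baseline)           # tensor([[0.3000, 0.5000, 0.7000]])
print(scores - baseline)  # tensor([[ 0.6000,  0.0000, -0.6000]]) -> the advantages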
captioning/utils/.DS_Store ADDED
Binary file (6.15 kB). View file
 
captioning/utils/__init__.py ADDED
File without changes
captioning/utils/config.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # Copy from fvcore
3
+
4
+ import logging
5
+ import os
6
+ from typing import Any
7
+ import yaml
8
+ from yacs.config import CfgNode as _CfgNode
9
+
10
+ import io as PathManager  # lightweight stand-in for fvcore's PathManager; only .open() is used here
11
+
12
+ BASE_KEY = "_BASE_"
13
+
14
+
15
+ class CfgNode(_CfgNode):
16
+ """
17
+ Our own extended version of :class:`yacs.config.CfgNode`.
18
+ It contains the following extra features:
19
+
20
+ 1. The :meth:`merge_from_file` method supports the "_BASE_" key,
21
+ which allows the new CfgNode to inherit all the attributes from the
22
+ base configuration file.
23
+ 2. Keys that start with "COMPUTED_" are treated as insertion-only
24
+ "computed" attributes. They can be inserted regardless of whether
25
+ the CfgNode is frozen or not.
26
+ 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
27
+ expressions in config. See examples in
28
+ https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
29
+ Note that this may lead to arbitrary code execution: you must not
30
+ load a config file from untrusted sources before manually inspecting
31
+ the content of the file.
32
+ """
33
+
34
+ @staticmethod
35
+ def load_yaml_with_base(filename, allow_unsafe = False):
36
+ """
37
+ Just like `yaml.load(open(filename))`, but inherit attributes from its
38
+ `_BASE_`.
39
+
40
+ Args:
41
+ filename (str): the file name of the current config. Will be used to
42
+ find the base config file.
43
+ allow_unsafe (bool): whether to allow loading the config file with
44
+ `yaml.unsafe_load`.
45
+
46
+ Returns:
47
+ (dict): the loaded yaml
48
+ """
49
+ with PathManager.open(filename, "r") as f:
50
+ try:
51
+ cfg = yaml.safe_load(f)
52
+ except yaml.constructor.ConstructorError:
53
+ if not allow_unsafe:
54
+ raise
55
+ logger = logging.getLogger(__name__)
56
+ logger.warning(
57
+ "Loading config {} with yaml.unsafe_load. Your machine may "
58
+ "be at risk if the file contains malicious content.".format(
59
+ filename
60
+ )
61
+ )
62
+ f.close()
63
+ with open(filename, "r") as f:
64
+ cfg = yaml.unsafe_load(f)
65
+
66
+ def merge_a_into_b(a, b):
67
+ # merge dict a into dict b. values in a will overwrite b.
68
+ for k, v in a.items():
69
+ if isinstance(v, dict) and k in b:
70
+ assert isinstance(
71
+ b[k], dict
72
+ ), "Cannot inherit key '{}' from base!".format(k)
73
+ merge_a_into_b(v, b[k])
74
+ else:
75
+ b[k] = v
76
+
77
+ if BASE_KEY in cfg:
78
+ base_cfg_file = cfg[BASE_KEY]
79
+ if base_cfg_file.startswith("~"):
80
+ base_cfg_file = os.path.expanduser(base_cfg_file)
81
+ if not any(
82
+ map(base_cfg_file.startswith, ["/", "https://", "http://"])
83
+ ):
84
+ # the path to base cfg is relative to the config file itself.
85
+ base_cfg_file = os.path.join(
86
+ os.path.dirname(filename), base_cfg_file
87
+ )
88
+ base_cfg = CfgNode.load_yaml_with_base(
89
+ base_cfg_file, allow_unsafe=allow_unsafe
90
+ )
91
+ del cfg[BASE_KEY]
92
+
93
+ merge_a_into_b(cfg, base_cfg)
94
+ return base_cfg
95
+ return cfg
96
+
97
+ def merge_from_file(self, cfg_filename, allow_unsafe = False):
98
+ """
99
+ Merge configs from a given yaml file.
100
+
101
+ Args:
102
+ cfg_filename: the file name of the yaml config.
103
+ allow_unsafe: whether to allow loading the config file with
104
+ `yaml.unsafe_load`.
105
+ """
106
+ loaded_cfg = CfgNode.load_yaml_with_base(
107
+ cfg_filename, allow_unsafe=allow_unsafe
108
+ )
109
+ loaded_cfg = type(self)(loaded_cfg)
110
+ self.merge_from_other_cfg(loaded_cfg)
111
+
112
+ # Forward the following calls to base, but with a check on the BASE_KEY.
113
+ def merge_from_other_cfg(self, cfg_other):
114
+ """
115
+ Args:
116
+ cfg_other (CfgNode): configs to merge from.
117
+ """
118
+ assert (
119
+ BASE_KEY not in cfg_other
120
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
121
+ return super().merge_from_other_cfg(cfg_other)
122
+
123
+ def merge_from_list(self, cfg_list):
124
+ """
125
+ Args:
126
+ cfg_list (list): list of configs to merge from.
127
+ """
128
+ keys = set(cfg_list[0::2])
129
+ assert (
130
+ BASE_KEY not in keys
131
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
132
+ return super().merge_from_list(cfg_list)
133
+
134
+ def __setattr__(self, name, val):
135
+ if name.startswith("COMPUTED_"):
136
+ if name in self:
137
+ old_val = self[name]
138
+ if old_val == val:
139
+ return
140
+ raise KeyError(
141
+ "Computed attributed '{}' already exists "
142
+ "with a different value! old={}, new={}.".format(
143
+ name, old_val, val
144
+ )
145
+ )
146
+ self[name] = val
147
+ else:
148
+ super().__setattr__(name, val)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
153
+ print(cfg)
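For reference, a small sketch of the `_BASE_` inheritance implemented above; base.yml and child.yml are hypothetical files and the keys are made up.

# base.yml:               # child.yml:
#   MODEL:                #   _BASE_: base.yml
#     RNN_SIZE: 512       #   MODEL:
#     LAYERS: 1           #     RNN_SIZE: 1024

from captioning.utils.config import CfgNode

cfg = CfgNode(CfgNode.load_yaml_with_base('child.yml'))
print(cfg.MODEL.RNN_SIZE)  # 1024 -- overridden by child.yml
print(cfg.MODEL.LAYERS)    # 1    -- inherited from base.yml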
captioning/utils/div_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from random import uniform
2
+ import numpy as np
3
+ from collections import OrderedDict, defaultdict
4
+ from itertools import tee
5
+ import time
6
+
7
+ # -----------------------------------------------
8
+ def find_ngrams(input_list, n):
9
+ return zip(*[input_list[i:] for i in range(n)])
10
+
11
+ def compute_div_n(caps,n=1):
12
+ aggr_div = []
13
+ for k in caps:
14
+ all_ngrams = set()
15
+ lenT = 0.
16
+ for c in caps[k]:
17
+ tkns = c.split()
18
+ lenT += len(tkns)
19
+ ng = find_ngrams(tkns, n)
20
+ all_ngrams.update(ng)
21
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
22
+ return np.array(aggr_div).mean(), np.array(aggr_div)
23
+
24
+ def compute_global_div_n(caps,n=1):
25
+ aggr_div = []
26
+ all_ngrams = set()
27
+ lenT = 0.
28
+ for k in caps:
29
+ for c in caps[k]:
30
+ tkns = c.split()
31
+ lenT += len(tkns)
32
+ ng = find_ngrams(tkns, n)
33
+ all_ngrams.update(ng)
34
+ if n == 1:
35
+ aggr_div.append(float(len(all_ngrams)))
36
+ else:
37
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
38
+ return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
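A quick, assumed example of the Div-1 statistic computed by compute_div_n above: the number of distinct n-grams divided by the total token count, averaged over images.

from captioning.utils.div_utils import compute_div_n

caps = {'img1': ['a red shoe', 'a blue shoe']}  # 2 captions, 6 tokens, 4 distinct unigrams
mean_div1, per_image_div1 = compute_div_n(caps, n=1)
print(round(float(mean_div1), 3))  # 0.667 = 4 / 6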
captioning/utils/eval_multi.py ADDED
@@ -0,0 +1,218 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import numpy as np
9
+ import json
10
+ from json import encoder
11
+ import random
12
+ import string
13
+ import time
14
+ import os
15
+ import sys
16
+ from . import misc as utils
17
+ from .eval_utils import getCOCO
18
+
19
+ from .div_utils import compute_div_n, compute_global_div_n
20
+
21
+ import sys
22
+ try:
23
+ sys.path.append("coco-caption")
24
+ annFile = 'coco-caption/annotations/captions_val2014.json'
25
+ from pycocotools.coco import COCO
26
+ from pycocoevalcap.eval import COCOEvalCap
27
+ from pycocoevalcap.eval_spice import COCOEvalCapSpice
28
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
29
+ from pycocoevalcap.bleu.bleu import Bleu
30
+ sys.path.append("cider")
31
+ from pyciderevalcap.cider.cider import Cider
32
+ except ImportError:
33
+ print('Warning: requirements for eval_multi not satisfied')
34
+
35
+
36
+ def eval_allspice(dataset, preds_n, model_id, split):
37
+ coco = getCOCO(dataset)
38
+ valids = coco.getImgIds()
39
+
40
+ capsById = {}
41
+ for d in preds_n:
42
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
43
+
44
+ # filter results to only those in MSCOCO validation set (will be about a third)
45
+ preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
46
+ print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
47
+ cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
48
+ json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
49
+
50
+ # Eval AllSPICE
51
+ cocoRes_n = coco.loadRes(cache_path_n)
52
+ cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
53
+ cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
54
+ cocoEvalAllSPICE.evaluate()
55
+
56
+ out = {}
57
+ for metric, score in cocoEvalAllSPICE.eval.items():
58
+ out['All'+metric] = score
59
+
60
+ imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
61
+ # collect SPICE_sub_score
62
+ for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
63
+ if k != 'All':
64
+ out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
65
+ out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
66
+ for p in preds_filt_n:
67
+ image_id, caption = p['image_id'], p['caption']
68
+ imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
69
+ return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
70
+
71
+ def eval_oracle(dataset, preds_n, model_id, split):
72
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
73
+
74
+ coco = getCOCO(dataset)
75
+ valids = coco.getImgIds()
76
+
77
+ capsById = {}
78
+ for d in preds_n:
79
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
80
+
81
+ sample_n = capsById[list(capsById.keys())[0]]
82
+ for i in range(len(capsById[list(capsById.keys())[0]])):
83
+ preds = [_[i] for _ in capsById.values()]
84
+
85
+ json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
86
+
87
+ cocoRes = coco.loadRes(cache_path)
88
+ cocoEval = COCOEvalCap(coco, cocoRes)
89
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
90
+ cocoEval.evaluate()
91
+
92
+ imgToEval = cocoEval.imgToEval
93
+ for img_id in capsById.keys():
94
+ tmp = imgToEval[img_id]
95
+ for k in tmp['SPICE'].keys():
96
+ if k != 'All':
97
+ tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
98
+ if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
99
+ tmp['SPICE_'+k] = -100
100
+ tmp['SPICE'] = tmp['SPICE']['All']['f']
101
+ if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
102
+ capsById[img_id][i]['scores'] = imgToEval[img_id]
103
+
104
+ out = {'overall': {}, 'ImgToEval': {}}
105
+ for img_id in capsById.keys():
106
+ out['ImgToEval'][img_id] = {}
107
+ for metric in capsById[img_id][0]['scores'].keys():
108
+ if metric == 'image_id': continue
109
+ out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
110
+ out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
111
+ out['ImgToEval'][img_id]['captions'] = capsById[img_id]
112
+ for metric in list(out['ImgToEval'].values())[0].keys():
113
+ if metric == 'captions':
114
+ continue
115
+ tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
116
+ tmp = tmp[tmp!=-100]
117
+ out['overall'][metric] = tmp.mean()
118
+
119
+ return out
120
+
121
+ def eval_div_stats(dataset, preds_n, model_id, split):
122
+ tokenizer = PTBTokenizer()
123
+
124
+ capsById = {}
125
+ for i, d in enumerate(preds_n):
126
+ d['id'] = i
127
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
128
+
129
+ n_caps_perimg = len(capsById[list(capsById.keys())[0]])
130
+ print(n_caps_perimg)
131
+ _capsById = capsById # save the untokenized version
132
+ capsById = tokenizer.tokenize(capsById)
133
+
134
+ div_1, adiv_1 = compute_div_n(capsById,1)
135
+ div_2, adiv_2 = compute_div_n(capsById,2)
136
+
137
+ globdiv_1, _= compute_global_div_n(capsById,1)
138
+
139
+ print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
140
+
141
+ # compute mbleu
142
+ scorer = Bleu(4)
143
+ all_scrs = []
144
+ scrperimg = np.zeros((n_caps_perimg, len(capsById)))
145
+
146
+ for i in range(n_caps_perimg):
147
+ tempRefsById = {}
148
+ candsById = {}
149
+ for k in capsById:
150
+ tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
151
+ candsById[k] = [capsById[k][i]]
152
+
153
+ score, scores = scorer.compute_score(tempRefsById, candsById)
154
+ all_scrs.append(score)
155
+ scrperimg[i,:] = scores[1]
156
+
157
+ all_scrs = np.array(all_scrs)
158
+
159
+ out = {}
160
+ out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
161
+ for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
162
+ out['overall'].update({'mBLeu_%d'%(k+1): score})
163
+ imgToEval = {}
164
+ for i,imgid in enumerate(capsById.keys()):
165
+ imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
166
+ imgToEval[imgid]['individuals'] = []
167
+ for j, d in enumerate(_capsById[imgid]):
168
+ imgToEval[imgid]['individuals'].append(preds_n[d['id']])
169
+ imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
170
+ out['ImgToEval'] = imgToEval
171
+
172
+ print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
173
+ print(all_scrs.mean(axis=0))
174
+
175
+ return out
176
+
177
+ def eval_self_cider(dataset, preds_n, model_id, split):
178
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
179
+
180
+ coco = getCOCO(dataset)
181
+ valids = coco.getImgIds()
182
+
183
+ # Get Cider_scorer
184
+ Cider_scorer = Cider(df='corpus')
185
+
186
+ tokenizer = PTBTokenizer()
187
+ gts = {}
188
+ for imgId in valids:
189
+ gts[imgId] = coco.imgToAnns[imgId]
190
+ gts = tokenizer.tokenize(gts)
191
+
192
+ for imgId in valids:
193
+ Cider_scorer.cider_scorer += (None, gts[imgId])
194
+ Cider_scorer.cider_scorer.compute_doc_freq()
195
+ Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
196
+
197
+ # Prepare captions
198
+ capsById = {}
199
+ for d in preds_n:
200
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
201
+
202
+ capsById = tokenizer.tokenize(capsById)
203
+ imgIds = list(capsById.keys())
204
+ scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
205
+
206
+ def get_div(eigvals):
207
+ eigvals = np.clip(eigvals, 0, None)
208
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
209
+ sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
210
+ score = np.mean(np.array(sc_scores))
211
+
212
+ imgToEval = {}
213
+ for i, image_id in enumerate(imgIds):
214
+ imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
215
+ return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
216
+
217
+
218
+
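The self-CIDEr diversity above turns the eigenvalues of a caption-similarity kernel into a [0, 1] score. A rough sketch of the two extremes, with the eigenvalues assumed rather than computed from real captions (get_div reproduced from eval_self_cider for illustration):

import numpy as np

def get_div(eigvals):
    eigvals = np.clip(eigvals, 0, None)
    return -np.log(np.sqrt(eigvals[-1]) / np.sqrt(eigvals).sum()) / np.log(len(eigvals))

print(get_div(np.array([0.0, 0.0, 3.0])))  # 0.0 -> rank-1 kernel, all captions effectively identical
print(get_div(np.array([1.0, 1.0, 1.0])))  # 1.0 -> equal eigenvalues, maximally diverse captions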
captioning/utils/eval_utils.py ADDED
@@ -0,0 +1,323 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import numpy as np
10
+ import json
11
+ from json import encoder
12
+ import random
13
+ import string
14
+ import time
15
+ import os
16
+ import sys
17
+ from . import misc as utils
18
+
19
+ # sys.path.insert(0, os.getcwd())
20
+
21
+ # sys.path.append("coco-caption")
22
+
23
+ # load coco-caption if available
24
+
25
+ from coco_caption.pycocotools.coco import COCO
26
+ from coco_caption.pycocoevalcap.eval import COCOEvalCap
27
+
28
+ # try:
29
+ # # sys.path.append("coco-caption")
30
+ # # from pycocotools.coco import COCO
31
+ # # from pycocoevalcap.eval import COCOEvalCap
32
+ # from coco_caption.pycocotools.coco import COCO
33
+ # from coco_caption.pycocoevalcap.eval import COCOEvalCap
34
+ # except:
35
+ # print('Warning: coco-caption not available')
36
+
37
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
38
+ bad_endings += ['UNK', 'has', 'and', 'more']
39
+
40
+
41
+ def count_bad(sen):
42
+ sen = sen.split(' ')
43
+ if sen[-1] in bad_endings:
44
+ return 1
45
+ else:
46
+ return 0
47
+
48
+
49
+ def getCOCO(dataset):
50
+ if 'coco' in dataset:
51
+ annFile = 'coco-caption/annotations/captions_val2014.json'
52
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
53
+ annFile = 'data/f30k_captions4eval.json'
54
+ # elif 'relative' in dataset:
55
+ # annFile = 'data/dress/features_simulator/caption_relative.json'
56
+ elif 'dress' in dataset:
57
+ annFile = 'data/dress/features_simulator/caption_relative.json'
58
+ elif 'shirt' in dataset:
59
+ annFile = 'data/shirt/features_simulator/caption_relative.json'
60
+ elif 'toptee' in dataset:
61
+ annFile = 'data/toptee/features_simulator/caption_relative.json'
62
+ elif 'fashion-gen' in dataset:
63
+ annFile = 'data/fashion-gen/features_simulator/caption_direct.json'
64
+ elif 'shoe' in dataset:
65
+ annFile = 'data/shoe/features_simulator/caption_relative.json'
66
+ return COCO(annFile)
67
+
68
+
69
+ def language_eval(dataset, preds, preds_n, eval_kwargs, split):
70
+ model_id = eval_kwargs['id']
71
+ eval_oracle = eval_kwargs.get('eval_oracle', 0)
72
+
73
+ # create output dictionary
74
+ out = {}
75
+
76
+ if len(preds_n) > 0:
77
+ # vocab size and novel sentences
78
+ if 'coco' in dataset:
79
+ dataset_file = 'data/dataset_coco.json'
80
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
81
+ dataset_file = 'data/dataset_flickr30k.json'
82
+ # elif 'relative' in dataset:
83
+ # dataset_file = 'data/dress/features_simulator/caption_relative.json'
84
+ elif 'dress' in dataset:
85
+ annFile = 'data/dress/features_simulator/caption_relative.json'
86
+ elif 'shirt' in dataset:
87
+ annFile = 'data/shirt/features_simulator/caption_relative.json'
88
+ elif 'toptee' in dataset:
89
+ annFile = 'data/toptee/features_simulator/caption_relative.json'
90
+ elif 'fashion-gen' in dataset:
91
+ annFile = 'data/fashion-gen/features_simulator/caption_direct.json'
92
+ elif 'shoe' in dataset:
93
+ annFile = 'data/shoe/features_simulator/caption_relative.json'
94
+ training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
95
+ generated_sentences = set([_['caption'] for _ in preds_n])
96
+ novels = generated_sentences - training_sentences
97
+ out['novel_sentences'] = float(len(novels)) / len(preds_n)
98
+ tmp = [_.split() for _ in generated_sentences]
99
+ words = []
100
+ for _ in tmp:
101
+ words += _
102
+ out['vocab_size'] = len(set(words))
103
+
104
+ # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
105
+
106
+ # cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')\
107
+ cache_path = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', '.cache_'+ model_id + '_' + split + '.json')
108
+
109
+ coco = getCOCO(dataset)
110
+ valids = coco.getImgIds()
111
+
112
+ # filter results to only those in MSCOCO validation set
113
+ preds_filt = [p for p in preds if p['image_id'] in valids]
114
+ mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
115
+ mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
116
+ print('using %d/%d predictions' % (len(preds_filt), len(preds)))
117
+ json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
118
+
119
+ cocoRes = coco.loadRes(cache_path)
120
+ cocoEval = COCOEvalCap(coco, cocoRes)
121
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
122
+ cocoEval.evaluate()
123
+
124
+ for metric, score in cocoEval.eval.items():
125
+ out[metric] = score
126
+ # Add mean perplexity
127
+ out['perplexity'] = mean_perplexity
128
+ out['entropy'] = mean_entropy
129
+
130
+ imgToEval = cocoEval.imgToEval
131
+ for k in list(imgToEval.values())[0]['SPICE'].keys():
132
+ if k != 'All':
133
+ out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
134
+ out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
135
+ for p in preds_filt:
136
+ image_id, caption = p['image_id'], p['caption']
137
+ imgToEval[image_id]['caption'] = caption
138
+
139
+ if len(preds_n) > 0:
140
+ from . import eval_multi
141
+ # cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
142
+ cache_path_n = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', '.cache_'+ model_id + '_' + split + '_n.json')
143
+ allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
144
+ out.update(allspice['overall'])
145
+ div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
146
+ out.update(div_stats['overall'])
147
+ if eval_oracle:
148
+ oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
149
+ out.update(oracle['overall'])
150
+ else:
151
+ oracle = None
152
+ self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
153
+ out.update(self_cider['overall'])
154
+ with open(cache_path_n, 'w') as outfile:
155
+ json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
156
+
157
+ out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
158
+ # outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
159
+ outfile_path = os.path.join('results/log_'+eval_kwargs['topic']+'_'+model_id+'/eval_results_'+eval_kwargs['topic']+'/', model_id + '_' + split + '.json')
160
+ with open(outfile_path, 'w') as outfile:
161
+ json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
162
+
163
+ return out
164
+
165
+ def eval_split(model, crit, loader, eval_kwargs={}):
166
+ verbose = eval_kwargs.get('verbose', True)
167
+ verbose_beam = eval_kwargs.get('verbose_beam', 0)
168
+ verbose_loss = eval_kwargs.get('verbose_loss', 1)
169
+ num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
170
+ split = eval_kwargs.get('split', 'val')
171
+ lang_eval = eval_kwargs.get('language_eval', 0)
172
+ dataset = eval_kwargs.get('dataset', 'coco')
173
+ beam_size = eval_kwargs.get('beam_size', 1)
174
+ sample_n = eval_kwargs.get('sample_n', 1)
175
+ remove_bad_endings = eval_kwargs.get('remove_bad_endings', 1)
176
+ os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
177
+ device = eval_kwargs.get('device', 'cuda')
178
+
179
+ # Make sure in the evaluation mode
180
+ model.eval()
181
+
182
+ loader.reset_iterator(split)
183
+
184
+ n = 0
185
+ loss = 0
186
+ loss_sum = 0
187
+ loss_evals = 1e-8
188
+ predictions = []
189
+ n_predictions = [] # when sample_n > 1
190
+ while True:
191
+ data = loader.get_batch(split)
192
+ n = n + len(data['infos'])
193
+
194
+ tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
195
+ tmp = [_.to(device) if _ is not None else _ for _ in tmp]
196
+ fc_feats, att_feats, labels, masks, att_masks = tmp
197
+
198
+ if labels is not None and verbose_loss:
199
+ # forward the model to get loss
200
+ with torch.no_grad():
201
+ loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
202
+ loss_sum = loss_sum + loss
203
+ loss_evals = loss_evals + 1
204
+
205
+ # forward the model to also get generated samples for each image
206
+ with torch.no_grad():
207
+ tmp_eval_kwargs = eval_kwargs.copy()
208
+ tmp_eval_kwargs.update({'sample_n': 1})
209
+ seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
210
+ seq = seq.data
211
+
212
+ entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
213
+ perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
214
+
215
+ # Print beam search
216
+ if beam_size > 1 and verbose_beam:
217
+ for i in range(fc_feats.shape[0]):
218
+ print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
219
+ print('--' * 10)
220
+ sents = utils.decode_sequence(model.vocab, seq)
221
+
222
+ for k, sent in enumerate(sents):
223
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
224
+ if eval_kwargs.get('dump_path', 0) == 1:
225
+ entry['file_name'] = data['infos'][k]['file_path']
226
+ predictions.append(entry)
227
+ if eval_kwargs.get('dump_images', 0) == 1:
228
+ # dump the raw image to vis/ folder
229
+ cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
230
+ print(cmd)
231
+ os.system(cmd)
232
+
233
+ if verbose:
234
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
235
+
236
+ if sample_n > 1:
237
+ eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
238
+
239
+ # ix0 = data['bounds']['it_pos_now']
240
+ ix1 = data['bounds']['it_max']
241
+ if num_images != -1:
242
+ ix1 = min(ix1, num_images)
243
+ else:
244
+ num_images = ix1
245
+ for i in range(n - ix1):
246
+ predictions.pop()
247
+
248
+ if verbose:
249
+ print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))
250
+
251
+ if num_images >= 0 and n >= num_images:
252
+ break
253
+
254
+ lang_stats = None
255
+ if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
256
+ n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
257
+ # if not os.path.isdir('eval_results'):
258
+ # os.mkdir('eval_results')
259
+ if not os.path.isdir('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic']):
260
+ os.mkdir('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic'])
261
+ # torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
262
+ torch.save((predictions, n_predictions), os.path.join('results/log_'+eval_kwargs['topic']+'_'+eval_kwargs['id']+'/eval_results_'+eval_kwargs['topic']+'/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
263
+ if lang_eval == 1:
264
+ lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
265
+
266
+ # Switch back to training mode
267
+ model.train()
268
+ return loss_sum/loss_evals, predictions, lang_stats
269
+
270
+
271
+ # Only run when sample_n > 0
272
+ def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
273
+ verbose = eval_kwargs.get('verbose', True)
274
+ beam_size = eval_kwargs.get('beam_size', 1)
275
+ sample_n = eval_kwargs.get('sample_n', 1)
276
+ sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
277
+
278
+ fc_feats, att_feats, att_masks, data = input_data
279
+
280
+ tmp_eval_kwargs = eval_kwargs.copy()
281
+ if sample_n_method == 'bs':
282
+ # case 1 sample_n == beam size
283
+ tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
284
+ with torch.no_grad():
285
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
286
+ for k in range(fc_feats.shape[0]):
287
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
288
+ for sent in _sents:
289
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
290
+ n_predictions.append(entry)
291
+ # case 2 sample / gumbel / topk sampling/ nucleus sampling
292
+ elif sample_n_method == 'sample' or \
293
+ sample_n_method == 'gumbel' or \
294
+ sample_n_method.startswith('top'):
295
+ tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
296
+ with torch.no_grad():
297
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
298
+ _sents = utils.decode_sequence(model.vocab, _seq)
299
+ _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
300
+ for k, sent in enumerate(_sents):
301
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
302
+ n_predictions.append(entry)
303
+ elif sample_n_method == 'dbs':
304
+ # Use diverse beam search
305
+ tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
306
+ with torch.no_grad():
307
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
308
+ for k in range(fc_feats.shape[0]):  # no loader in scope here; iterate over the batch dimension
309
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
310
+ for sent in _sents:
311
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
312
+ n_predictions.append(entry)
313
+ else:
314
+ tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
315
+ with torch.no_grad():
316
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
317
+ _sents = utils.decode_sequence(model.vocab, _seq)
318
+ for k, sent in enumerate(_sents):
319
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
320
+ n_predictions.append(entry)
321
+ if verbose:
322
+ for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
323
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
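Note on the 'perplexity' reported by eval_split above: it is (approximately) the average negative log-probability of the sampled tokens, i.e. a log-perplexity rather than the exponentiated value. A toy check of the reduction, with made-up log-probabilities for a two-word caption plus its END token:

import torch

seq = torch.tensor([[3, 7, 0]])            # two word tokens, then END (0)
logp = torch.tensor([[-0.5, -1.5, -1.0]])  # log-prob of each sampled token
perplexity = -logp.sum(1) / ((seq > 0).float().sum(1) + 1)
print(perplexity)  # tensor([1.]) = (0.5 + 1.5 + 1.0) / (2 + 1)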
captioning/utils/misc.py ADDED
@@ -0,0 +1,249 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import collections
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import torch.optim as optim
10
+ import os
11
+
12
+ import torch.nn.functional as F
13
+
14
+ import six
15
+ from six.moves import cPickle
16
+
17
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
18
+ bad_endings += ['UNK', 'has', 'and', 'more']
19
+
20
+ def pickle_load(f):
21
+ """ Load a pickle.
22
+ Parameters
23
+ ----------
24
+ f: file-like object
25
+ """
26
+ if six.PY3:
27
+ return cPickle.load(f, encoding='latin-1')
28
+ else:
29
+ return cPickle.load(f)
30
+
31
+
32
+ def pickle_dump(obj, f):
33
+ """ Dump a pickle.
34
+ Parameters
35
+ ----------
36
+ obj: pickled object
37
+ f: file-like object
38
+ """
39
+ if six.PY3:
40
+ return cPickle.dump(obj, f, protocol=2)
41
+ else:
42
+ return cPickle.dump(obj, f)
43
+
44
+
45
+ # modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
46
+ def serialize_to_tensor(data):
47
+ device = torch.device("cpu")
48
+
49
+ buffer = cPickle.dumps(data)
50
+ storage = torch.ByteStorage.from_buffer(buffer)
51
+ tensor = torch.ByteTensor(storage).to(device=device)
52
+ return tensor
53
+
54
+
55
+ def deserialize(tensor):
56
+ buffer = tensor.cpu().numpy().tobytes()
57
+ return cPickle.loads(buffer)
58
+
59
+
60
+ # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
61
+ def decode_sequence(ix_to_word, seq):
62
+ N, D = seq.size()
63
+ out = []
64
+ for i in range(N):
65
+ txt = ''
66
+ for j in range(D):
67
+ ix = seq[i,j]
68
+ if ix > 0 :
69
+ if j >= 1:
70
+ txt = txt + ' '
71
+ txt = txt + ix_to_word[str(ix.item())]
72
+ else:
73
+ break
74
+ if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
75
+ flag = 0
76
+ words = txt.split(' ')
77
+ for j in range(len(words)):
78
+ if words[-j-1] not in bad_endings:
79
+ flag = -j
80
+ break
81
+ txt = ' '.join(words[0:len(words)+flag])
82
+ out.append(txt.replace('@@ ', ''))
83
+ return out
84
+
85
+
86
+ def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
87
+ if len(append) > 0:
88
+ append = '_' + append
89
+ # if checkpoint_path doesn't exist
90
+ if not os.path.isdir(opt.checkpoint_path):
91
+ os.makedirs(opt.checkpoint_path)
92
+ checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
93
+ torch.save(model.state_dict(), checkpoint_path)
94
+ print("model saved to {}".format(checkpoint_path))
95
+ optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
96
+ torch.save(optimizer.state_dict(), optimizer_path)
97
+ with open(os.path.join(opt.checkpoint_path, 'infos%s.pkl' %(append)), 'wb') as f:
98
+ pickle_dump(infos, f)
99
+ if histories:
100
+ with open(os.path.join(opt.checkpoint_path, 'histories%s.pkl' %(append)), 'wb') as f:
101
+ pickle_dump(histories, f)
102
+
103
+
104
+ def set_lr(optimizer, lr):
105
+ for group in optimizer.param_groups:
106
+ group['lr'] = lr
107
+
108
+ def get_lr(optimizer):
109
+ for group in optimizer.param_groups:
110
+ return group['lr']
111
+
112
+
113
+ def build_optimizer(params, opt):
114
+ if opt.optim == 'rmsprop':
115
+ return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
116
+ elif opt.optim == 'adagrad':
117
+ return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
118
+ elif opt.optim == 'sgd':
119
+ return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
120
+ elif opt.optim == 'sgdm':
121
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
122
+ elif opt.optim == 'sgdmom':
123
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
124
+ elif opt.optim == 'adam':
125
+ return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
126
+ elif opt.optim == 'adamw':
127
+ return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
128
+ else:
129
+ raise Exception("bad option opt.optim: {}".format(opt.optim))
130
+
131
+
132
+ def penalty_builder(penalty_config):
133
+ if penalty_config == '':
134
+ return lambda x,y: y
135
+ pen_type, alpha = penalty_config.split('_')
136
+ alpha = float(alpha)
137
+ if pen_type == 'wu':
138
+ return lambda x,y: length_wu(x,y,alpha)
139
+ if pen_type == 'avg':
140
+ return lambda x,y: length_average(x,y,alpha)
141
+
142
+ def length_wu(length, logprobs, alpha=0.):
143
+ """
144
+ NMT length re-ranking score from
145
+ "Google's Neural Machine Translation System" :cite:`wu2016google`.
146
+ """
147
+
148
+ modifier = (((5 + length) ** alpha) /
149
+ ((5 + 1) ** alpha))
150
+ return (logprobs / modifier)
151
+
152
+ def length_average(length, logprobs, alpha=0.):
153
+ """
154
+ Returns the average probability of tokens in a sequence.
155
+ """
156
+ return logprobs / length
157
+
158
+
159
+ class NoamOpt(object):
160
+ "Optim wrapper that implements rate."
161
+ def __init__(self, model_size, factor, warmup, optimizer):
162
+ self.optimizer = optimizer
163
+ self._step = 0
164
+ self.warmup = warmup
165
+ self.factor = factor
166
+ self.model_size = model_size
167
+ self._rate = 0
168
+
169
+ def step(self):
170
+ "Update parameters and rate"
171
+ self._step += 1
172
+ rate = self.rate()
173
+ for p in self.optimizer.param_groups:
174
+ p['lr'] = rate
175
+ self._rate = rate
176
+ self.optimizer.step()
177
+
178
+ def rate(self, step = None):
179
+ "Implement `lrate` above"
180
+ if step is None:
181
+ step = self._step
182
+ return self.factor * \
183
+ (self.model_size ** (-0.5) *
184
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
185
+
186
+ def __getattr__(self, name):
187
+ return getattr(self.optimizer, name)
188
+
189
+ def state_dict(self):
190
+ state_dict = self.optimizer.state_dict()
191
+ state_dict['_step'] = self._step
192
+ return state_dict
193
+
194
+ def load_state_dict(self, state_dict):
195
+ if '_step' in state_dict:
196
+ self._step = state_dict['_step']
197
+ del state_dict['_step']
198
+ self.optimizer.load_state_dict(state_dict)
199
+
200
+ class ReduceLROnPlateau(object):
201
+ "Optim wrapper that implements rate."
202
+ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
203
+ self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
204
+ self.optimizer = optimizer
205
+ self.current_lr = get_lr(optimizer)
206
+
207
+ def step(self):
208
+ "Update parameters and rate"
209
+ self.optimizer.step()
210
+
211
+ def scheduler_step(self, val):
212
+ self.scheduler.step(val)
213
+ self.current_lr = get_lr(self.optimizer)
214
+
215
+ def state_dict(self):
216
+ return {'current_lr':self.current_lr,
217
+ 'scheduler_state_dict': self.scheduler.state_dict(),
218
+ 'optimizer_state_dict': self.optimizer.state_dict()}
219
+
220
+ def load_state_dict(self, state_dict):
221
+ if 'current_lr' not in state_dict:
222
+ # it's a normal optimizer state dict
223
+ self.optimizer.load_state_dict(state_dict)
224
+ set_lr(self.optimizer, self.current_lr) # use the lr from the option
225
+ else:
226
+ # it's a scheduler state dict
227
+ self.current_lr = state_dict['current_lr']
228
+ self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
229
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
230
+ # current_lr is actually useless in this case
231
+
232
+ def rate(self, step = None):
233
+ "Implement `lrate` above"
234
+ if step is None:
235
+ step = self._step
236
+ return self.factor * \
237
+ (self.model_size ** (-0.5) *
238
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
239
+
240
+ def __getattr__(self, name):
241
+ return getattr(self.optimizer, name)
242
+
243
+ def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
244
+ # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
245
+ # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
246
+ optim_func = dict(adam=torch.optim.Adam,
247
+ adamw=torch.optim.AdamW)[optim_func]
248
+ return NoamOpt(model.d_model, factor, warmup,
249
+ optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
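get_std_opt above wires up the Noam learning-rate schedule, lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), which grows linearly for `warmup` steps and then decays as 1/sqrt(step). A quick numeric check with assumed values (d_model=512, warmup=2000):

def noam_rate(step, model_size=512, factor=1, warmup=2000):
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_rate(2000))   # ~9.9e-4, the peak, reached at step == warmup
print(noam_rate(20000))  # ~3.1e-4, decayed by a factor of sqrt(10)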
captioning/utils/opts.py ADDED
@@ -0,0 +1,365 @@
1
+ from __future__ import print_function
2
+ import argparse
3
+
4
+
5
+ def if_use_feat(caption_model):
6
+ # Decide if load attention feature according to caption model
7
+ if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
8
+ use_att, use_fc = False, True
9
+ elif caption_model == 'language_model':
10
+ use_att, use_fc = False, False
11
+ elif caption_model in ['updown', 'topdown']:
12
+ use_fc, use_att = True, True
13
+ else:
14
+ use_att, use_fc = True, False
15
+ return use_fc, use_att
16
+
17
+
18
+ def parse_opt():
19
+ parser = argparse.ArgumentParser()
20
+ # Data input settings
21
+ parser.add_argument('--input_json', type=str, default='data/coco.json',
22
+ help='path to the json file containing additional info and vocab')
23
+ parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
24
+ help='path to the directory containing the preprocessed fc feats')
25
+ parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
26
+ help='path to the directory containing the preprocessed att feats')
27
+ parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
28
+ help='path to the directory containing the boxes of att feats')
29
+ parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
30
+ help='path to the h5file containing the preprocessed dataset')
31
+ parser.add_argument('--data_in_memory', action='store_true',
32
+ help='True if we want to save the features in memory')
33
+ parser.add_argument('--start_from', type=str, default=None,
34
+ help="""continue training from saved model at this path. Path must contain files saved by previous training process:
35
+ 'infos.pkl' : configuration;
36
+ 'model.pth' : weights
37
+ """)
38
+ parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
39
+ help='Cached token file for calculating cider score during self critical training.')
40
+
41
+ # Model settings
42
+ parser.add_argument('--caption_model', type=str, default="show_tell",
43
+ help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
44
+ parser.add_argument('--rnn_size', type=int, default=512,
45
+ help='size of the rnn in number of hidden nodes in each layer')
46
+ parser.add_argument('--num_layers', type=int, default=1,
47
+ help='number of layers in the RNN')
48
+ parser.add_argument('--rnn_type', type=str, default='lstm',
49
+ help='rnn, gru, or lstm')
50
+ parser.add_argument('--input_encoding_size', type=int, default=512,
51
+ help='the encoding size of each token in the vocabulary, and the image.')
52
+ parser.add_argument('--att_hid_size', type=int, default=512,
53
+ help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
54
+ parser.add_argument('--fc_feat_size', type=int, default=2048,
55
+ help='2048 for resnet, 4096 for vgg')
56
+ parser.add_argument('--att_feat_size', type=int, default=2048,
57
+ help='2048 for resnet, 512 for vgg')
58
+ parser.add_argument('--logit_layers', type=int, default=1,
59
+ help='number of layers in the RNN')
60
+
61
+
62
+ parser.add_argument('--use_bn', type=int, default=0,
63
+ help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
64
+
65
+ # feature manipulation
66
+ parser.add_argument('--norm_att_feat', type=int, default=0,
67
+ help='If normalize attention features')
68
+ parser.add_argument('--use_box', type=int, default=0,
69
+ help='If use box features')
70
+ parser.add_argument('--norm_box_feat', type=int, default=0,
71
+ help='If use box, do we normalize box feature')
72
+
73
+ # Optimization: General
74
+ parser.add_argument('--max_epochs', type=int, default=-1,
75
+ help='number of epochs')
76
+ parser.add_argument('--batch_size', type=int, default=16,
77
+ help='minibatch size')
78
+ parser.add_argument('--grad_clip_mode', type=str, default='value',
79
+ help='value or norm')
80
+ parser.add_argument('--grad_clip_value', type=float, default=0.1,
81
+ help='clip gradients at this value/max_norm, 0 means no clipping')
82
+ parser.add_argument('--drop_prob_lm', type=float, default=0.5,
83
+ help='strength of dropout in the Language Model RNN')
84
+ parser.add_argument('--self_critical_after', type=int, default=-1,
85
+ help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
86
+ parser.add_argument('--seq_per_img', type=int, default=5,
87
+ help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
88
+
89
+ # Sample related
90
+ add_eval_sample_opts(parser)
91
+
92
+ #Optimization: for the Language Model
93
+ parser.add_argument('--optim', type=str, default='adam',
94
+ help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
95
+ parser.add_argument('--learning_rate', type=float, default=4e-4,
96
+ help='learning rate')
97
+ parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
98
+ help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
99
+ parser.add_argument('--learning_rate_decay_every', type=int, default=3,
100
+ help='every how many iterations thereafter to drop LR?(in epoch)')
101
+ parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
102
+ help='every how many iterations thereafter to drop LR?(in epoch)')
103
+ parser.add_argument('--optim_alpha', type=float, default=0.9,
104
+ help='alpha for adam')
105
+ parser.add_argument('--optim_beta', type=float, default=0.999,
106
+ help='beta used for adam')
107
+ parser.add_argument('--optim_epsilon', type=float, default=1e-8,
108
+ help='epsilon that goes into denominator for smoothing')
109
+ parser.add_argument('--weight_decay', type=float, default=0,
110
+ help='weight_decay')
111
+ # Transformer
112
+ parser.add_argument('--label_smoothing', type=float, default=0,
113
+ help='')
114
+ parser.add_argument('--noamopt', action='store_true',
115
+ help='')
116
+ parser.add_argument('--noamopt_warmup', type=int, default=2000,
117
+ help='')
118
+ parser.add_argument('--noamopt_factor', type=float, default=1,
119
+ help='')
120
+ parser.add_argument('--reduce_on_plateau', action='store_true',
121
+ help='')
122
+ parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
123
+ help='')
124
+ parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
125
+ help='')
126
+ parser.add_argument('--cached_transformer', action='store_true',
127
+ help='')
128
+
129
+
130
+ parser.add_argument('--use_warmup', action='store_true',
131
+ help='warm up the learning rate?')
132
+
133
+ parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
134
+ help='at what iteration to start decaying the ground-truth feeding probability (scheduled sampling)')
135
+ parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
136
+ help='every how many iterations thereafter to increase the scheduled sampling probability')
137
+ parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
138
+ help='How much to update the prob')
139
+ parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
140
+ help='Maximum scheduled sampling prob.')
141
+
142
+
143
+ # Evaluation/Checkpointing
144
+ parser.add_argument('--val_images_use', type=int, default=3200,
145
+ help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
146
+ parser.add_argument('--save_checkpoint_every', type=int, default=2500,
147
+ help='how often to save a model checkpoint (in iterations)?')
148
+ parser.add_argument('--save_every_epoch', action='store_true',
149
+ help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
150
+ parser.add_argument('--save_history_ckpt', type=int, default=0,
151
+ help='If save checkpoints at every save point')
152
+ parser.add_argument('--checkpoint_path', type=str, default=None,
153
+ help='directory to store checkpointed models')
154
+ parser.add_argument('--language_eval', type=int, default=0,
155
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
156
+ parser.add_argument('--losses_log_every', type=int, default=25,
157
+ help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
158
+ parser.add_argument('--load_best_score', type=int, default=1,
159
+ help='Do we load previous best score when resuming training.')
160
+
161
+ # misc
162
+ parser.add_argument('--id', type=str, default='',
163
+ help='an id identifying this run/job. used in cross-val and appended when writing progress files')
164
+ parser.add_argument('--train_only', type=int, default=0,
165
+ help='if true then use 80k, else use 110k')
166
+ parser.add_argument('--topic', type=str, default='dress',
167
+ help='type of datasets, such as dress, shirt, toptee')
168
+
169
+
170
+ # Reward
171
+ parser.add_argument('--cider_reward_weight', type=float, default=1,
172
+ help='The reward weight from cider')
173
+ parser.add_argument('--bleu_reward_weight', type=float, default=0,
174
+ help='The reward weight from bleu4')
175
+
176
+
177
+ # Structure_loss
178
+ parser.add_argument('--structure_loss_weight', type=float, default=1,
179
+ help='')
180
+ parser.add_argument('--structure_after', type=int, default=-1,
181
+ help='T')
182
+ parser.add_argument('--structure_loss_type', type=str, default='seqnll',
183
+ help='')
184
+ parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
185
+ parser.add_argument('--entropy_reward_weight', type=float, default=0,
186
+ help='Entropy reward, seems very interesting')
187
+ parser.add_argument('--self_cider_reward_weight', type=float, default=0,
188
+ help='self cider reward')
189
+
190
+ # Used for self critical or structure. Used when sampling is need during training
191
+ parser.add_argument('--train_sample_n', type=int, default=1,
192
+ help='The reward weight from cider')
193
+ parser.add_argument('--train_sample_method', type=str, default='sample',
194
+ help='')
195
+ parser.add_argument('--train_beam_size', type=int, default=1,
196
+ help='')
197
+
198
+ # Used for self critical
199
+ parser.add_argument('--sc_sample_method', type=str, default='greedy',
200
+ help='')
201
+ parser.add_argument('--sc_beam_size', type=int, default=1,
202
+ help='')
203
+
204
+ parser.add_argument('--seed', type=int, default=42,
205
+ help='')
206
+
207
+ # For diversity evaluation during training
208
+ add_diversity_opts(parser)
209
+
210
+
211
+ # config
212
+ parser.add_argument('--cfg', type=str, default=None,
213
+ help='configuration; similar to what is used in detectron')
214
+ parser.add_argument(
215
+ '--set_cfgs', dest='set_cfgs',
216
+ help='Set config keys. Key value sequence seperate by whitespace.'
217
+ 'e.g. [key] [value] [key] [value]\n This has higher priority'
218
+ 'than cfg file but lower than other args. (You can only overwrite'
219
+ 'arguments that have alerady been defined in config file.)',
220
+ default=[], nargs='+')
221
+ # How will config be used
222
+ # 1) read cfg argument, and load the cfg file if it's not None
223
+ # 2) Overwrite cfg argument with set_cfgs
224
+ # 3) parse config argument to args.
225
+ # 4) in the end, parse command line argument and overwrite args
226
+
227
+ # step 1: read cfg_fn
228
+ args = parser.parse_args()
229
+ if args.cfg is not None or args.set_cfgs is not None:
230
+ from .config import CfgNode
231
+ if args.cfg is not None:
232
+ cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
233
+ else:
234
+ cn = CfgNode()
235
+ if args.set_cfgs is not None:
236
+ cn.merge_from_list(args.set_cfgs)
237
+ for k,v in cn.items():
238
+ if not hasattr(args, k):
239
+ print('Warning: key %s not in args' %k)
240
+ setattr(args, k, v)
241
+ args = parser.parse_args(namespace=args)
242
+
243
+ # Check if args are valid
244
+ assert args.rnn_size > 0, "rnn_size should be greater than 0"
245
+ assert args.num_layers > 0, "num_layers should be greater than 0"
246
+ assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
247
+ assert args.batch_size > 0, "batch_size should be greater than 0"
248
+ assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
249
+ assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
250
+ assert args.beam_size > 0, "beam_size should be greater than 0"
251
+ assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
252
+ assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
253
+ assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
254
+ assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
255
+ assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"
256
+
257
+ # default value for start_from and checkpoint_path
258
+ # args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
259
+ args.checkpoint_path = args.checkpoint_path or './results/log_{}_{}'.format(args.topic, args.id)
260
+ args.start_from = args.start_from or args.checkpoint_path
261
+
262
+ # Deal with feature things before anything
263
+ args.use_fc, args.use_att = if_use_feat(args.caption_model)
264
+ if args.use_box: args.att_feat_size = args.att_feat_size + 5
265
+
266
+ return args
267
+
268
+
269
+ def add_eval_options(parser):
270
+ # Basic options
271
+ parser.add_argument('--batch_size', type=int, default=0,
272
+ help='if > 0 then overrule, otherwise load from checkpoint.')
273
+ parser.add_argument('--num_images', type=int, default=-1,
274
+ help='how many images to use when periodically evaluating the loss? (-1 = all)')
275
+ parser.add_argument('--language_eval', type=int, default=0,
276
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
277
+ parser.add_argument('--dump_images', type=int, default=1,
278
+ help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
279
+ parser.add_argument('--dump_json', type=int, default=1,
280
+ help='Dump json with predictions into vis folder? (1=yes,0=no)')
281
+ parser.add_argument('--dump_path', type=int, default=0,
282
+ help='Write image paths along with predictions into vis json? (1=yes,0=no)')
283
+
284
+ # Sampling options
285
+ add_eval_sample_opts(parser)
286
+
287
+ # For evaluation on a folder of images:
288
+ parser.add_argument('--image_folder', type=str, default='',
289
+ help='If this is nonempty then will predict on the images in this folder path')
290
+ parser.add_argument('--image_root', type=str, default='',
291
+ help='In case the image paths have to be prepended with a root path to an image folder')
292
+ # For evaluation on MSCOCO images from some split:
293
+ parser.add_argument('--input_fc_dir', type=str, default='',
294
+ help='path to the directory containing the preprocessed fc features')
295
+ parser.add_argument('--input_att_dir', type=str, default='',
296
+ help='path to the directory containing the preprocessed att features')
297
+ parser.add_argument('--input_box_dir', type=str, default='',
298
+ help='path to the directory containing the preprocessed bounding box features')
299
+ parser.add_argument('--input_label_h5', type=str, default='',
300
+ help='path to the h5 file containing the preprocessed labels')
301
+ parser.add_argument('--input_json', type=str, default='',
302
+ help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
303
+ parser.add_argument('--split', type=str, default='test',
304
+ help='if running on MSCOCO images, which split to use: val|test|train')
305
+ parser.add_argument('--coco_json', type=str, default='',
306
+ help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
307
+ # misc
308
+ parser.add_argument('--id', type=str, default='',
309
+ help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
310
+ parser.add_argument('--verbose_beam', type=int, default=1,
311
+ help='if we need to print out all beam search beams.')
312
+ parser.add_argument('--verbose_loss', type=int, default=0,
313
+ help='if 1, also compute the loss against the ground truth during evaluation')
314
+
315
+ parser.add_argument('--seed', type=int, default=42,
316
+ help='random seed')
317
+
318
+ def add_diversity_opts(parser):
319
+ parser.add_argument('--sample_n', type=int, default=1,
320
+ help='Diverse sampling')
321
+ parser.add_argument('--sample_n_method', type=str, default='sample',
322
+ help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
323
+ parser.add_argument('--eval_oracle', type=int, default=1,
324
+ help='if 1, report oracle (best-of-n) scores over the diverse samples')
325
+
326
+
327
+ # Sampling related options
328
+ def add_eval_sample_opts(parser):
329
+ parser.add_argument('--sample_method', type=str, default='greedy',
330
+ help='greedy; sample; gumbel; top<int>, top<0-1>')
331
+ parser.add_argument('--beam_size', type=int, default=1,
332
+ help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
333
+ parser.add_argument('--max_length', type=int, default=8,
334
+ help='Maximum length during sampling')
335
+ parser.add_argument('--length_penalty', type=str, default='',
336
+ help='wu_X or avg_X, X is the alpha')
337
+ parser.add_argument('--group_size', type=int, default=1,
338
+ help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
339
+ parser.add_argument('--diversity_lambda', type=float, default=0.5,
340
+ help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
341
+ parser.add_argument('--temperature', type=float, default=1.0,
342
+ help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
343
+ parser.add_argument('--decoding_constraint', type=int, default=0,
344
+ help='if 1, do not allow the same word twice in a row')
345
+ parser.add_argument('--block_trigrams', type=int, default=0,
346
+ help='block repeated trigram.')
347
+ parser.add_argument('--remove_bad_endings', type=int, default=1,
348
+ help='Remove bad endings')
349
+ parser.add_argument('--suppress_UNK', type=int, default=1,
350
+ help='Not predicting UNK')
351
+
352
+
353
+ if __name__ == '__main__':
354
+ import sys
355
+ sys.argv = [sys.argv[0]]
356
+ args = parse_opt()
357
+ print(args)
358
+ print()
359
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
360
+ args1 = parse_opt()
361
+ print(dict(set(vars(args1).items()) - set(vars(args).items())))
362
+ print()
363
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
364
+ args2 = parse_opt()
365
+ print(dict(set(vars(args2).items()) - set(vars(args1).items())))
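
The `--cfg` / `--set_cfgs` mechanism in `parse_opt()` above layers three sources of options: values from the YAML file are applied first, `--set_cfgs` key/value pairs override them, and explicit command-line flags win over both, because the command line is re-parsed into the merged namespace at the end. A minimal usage sketch mirroring the self-test above; it assumes the file is importable as `captioning.utils.opts`, and the YAML path is taken from that self-test:

    import sys
    from captioning.utils import opts

    # Lowest to highest precedence: YAML given by --cfg, then --set_cfgs
    # key/value pairs, then explicit command-line flags.
    sys.argv = [sys.argv[0],
                '--cfg', 'configs/updown_long.yml',   # assumed to exist
                '--caption_model', 'att2in2']         # overrides the YAML value
    args = opts.parse_opt()
    print(args.caption_model)   # 'att2in2', regardless of the YAML setting
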
captioning/utils/resnet.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models.resnet
4
+ from torchvision.models.resnet import BasicBlock, Bottleneck
5
+ import torch.utils.model_zoo as model_zoo
6
+
7
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8
+ 'resnet152']
9
+
10
+ model_urls = {
11
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
12
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
13
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
14
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-cd907fc2.pth',
15
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-f82ba261.pth',
16
+ }
17
+
18
+ class ResNet(torchvision.models.resnet.ResNet):
19
+ def __init__(self, block, layers, num_classes=1000):
20
+ super(ResNet, self).__init__(block, layers, num_classes)
21
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
22
+ for i in range(2, 5):
23
+ getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
24
+ getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
25
+
26
+ def resnet18(pretrained=False):
27
+ """Constructs a ResNet-18 model.
28
+
29
+ Args:
30
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
31
+ """
32
+ model = ResNet(BasicBlock, [2, 2, 2, 2])
33
+ if pretrained:
34
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
35
+ return model
36
+
37
+
38
+ def resnet34(pretrained=False):
39
+ """Constructs a ResNet-34 model.
40
+
41
+ Args:
42
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
43
+ """
44
+ model = ResNet(BasicBlock, [3, 4, 6, 3])
45
+ if pretrained:
46
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
47
+ return model
48
+
49
+
50
+ def resnet50(pretrained=False):
51
+ """Constructs a ResNet-50 model.
52
+
53
+ Args:
54
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
55
+ """
56
+ model = ResNet(Bottleneck, [3, 4, 6, 3])
57
+ if pretrained:
58
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
59
+ return model
60
+
61
+
62
+ def resnet101(pretrained=False):
63
+ """Constructs a ResNet-101 model.
64
+
65
+ Args:
66
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
67
+ """
68
+
69
+ model = ResNet(Bottleneck, [3, 4, 23, 3])
70
+ if pretrained:
71
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
72
+ return model
73
+
74
+
75
+ def resnet152(pretrained=False):
76
+ """Constructs a ResNet-152 model.
77
+
78
+ Args:
79
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
80
+ """
81
+ model = ResNet(Bottleneck, [3, 8, 36, 3])
82
+ if pretrained:
83
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
84
+ return model
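
The `ResNet` subclass above only retunes the stem and downsampling: the max-pool becomes padding-free with `ceil_mode=True`, and the first block of `layer2` through `layer4` is forced to downsample in `conv1` (stride 2) with `conv2` left at stride 1. No parameter shapes change, so the torchvision checkpoints listed in `model_urls` still load. A minimal sanity-check sketch of the resulting feature-map size (the 224x224 input is an arbitrary choice):

    import torch
    from captioning.utils.resnet import resnet101

    cnn = resnet101(pretrained=False)   # set True to download the ImageNet weights
    cnn.eval()

    x = torch.zeros(1, 3, 224, 224)
    with torch.no_grad():
        x = cnn.maxpool(cnn.relu(cnn.bn1(cnn.conv1(x))))
        x = cnn.layer4(cnn.layer3(cnn.layer2(cnn.layer1(x))))
    print(x.shape)   # torch.Size([1, 2048, 7, 7])
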
captioning/utils/resnet_utils.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class myResnet(nn.Module):
6
+ def __init__(self, resnet):
7
+ super(myResnet, self).__init__()
8
+ self.resnet = resnet
9
+
10
+ def forward(self, img, att_size=14):
11
+ x = img.unsqueeze(0)
12
+
13
+ x = self.resnet.conv1(x)
14
+ x = self.resnet.bn1(x)
15
+ x = self.resnet.relu(x)
16
+ x = self.resnet.maxpool(x)
17
+
18
+ x = self.resnet.layer1(x)
19
+ x = self.resnet.layer2(x)
20
+ x = self.resnet.layer3(x)
21
+ x = self.resnet.layer4(x)
22
+
23
+ fc = x.mean(3).mean(2).squeeze()
24
+ att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
25
+
26
+ return fc, att
27
+
28
+
29
+ class ResNetBatch(nn.Module):
30
+ def __init__(self, resnet):
31
+ super(ResNetBatch, self).__init__()
32
+ self.resnet = resnet
33
+
34
+ def forward(self, x, att_size=14):
35
+ # size of x: nimages x nChannel x dim x dim
36
+
37
+ x = self.resnet.conv1(x)
38
+ x = self.resnet.bn1(x)
39
+ x = self.resnet.relu(x)
40
+ x = self.resnet.maxpool(x)
41
+
42
+ x = self.resnet.layer1(x)
43
+ x = self.resnet.layer2(x)
44
+ x = self.resnet.layer3(x)
45
+ x = self.resnet.layer4(x)
46
+
47
+ fc = x.mean(3).mean(2)
48
+ # att = F.adaptive_avg_pool2d(x, [att_size, att_size]).squeeze().permute(1, 2, 0)
49
+ att = F.adaptive_avg_pool2d(x, [att_size, att_size]).permute(0, 2, 3, 1)
50
+
51
+ return fc, att
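
`myResnet` runs a single preprocessed image through the backbone and returns a pooled `fc` vector (2048-d for a ResNet-101 backbone) plus an `att_size x att_size x 2048` attention grid; `ResNetBatch` does the same for a batch, keeping the leading batch dimension. A minimal sketch; the 224x224 input and `att_size=7` are arbitrary choices here:

    import torch
    from captioning.utils.resnet import resnet101
    from captioning.utils.resnet_utils import myResnet, ResNetBatch

    cnn = resnet101(pretrained=False)       # pretrained=True in real use
    single = myResnet(cnn).eval()
    batched = ResNetBatch(cnn).eval()

    with torch.no_grad():
        fc, att = single(torch.zeros(3, 224, 224), att_size=7)
        print(fc.shape, att.shape)          # (2048,) and (7, 7, 2048)
        fc_b, att_b = batched(torch.zeros(4, 3, 224, 224), att_size=7)
        print(fc_b.shape, att_b.shape)      # (4, 2048) and (4, 7, 7, 2048)
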
captioning/utils/rewards.py ADDED
@@ -0,0 +1,136 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import numpy as np
6
+ import time
7
+ from collections import OrderedDict
8
+ import torch
9
+
10
+ import sys
11
+ try:
12
+ sys.path.append("cider")
13
+ from pyciderevalcap.ciderD.ciderD import CiderD
14
+ from pyciderevalcap.cider.cider import Cider
15
+ sys.path.append("coco-caption")
16
+ from pycocoevalcap.bleu.bleu import Bleu
17
+ except ImportError:
18
+ print('cider or coco-caption missing')
19
+
20
+ CiderD_scorer = None
21
+ Cider_scorer = None
22
+ Bleu_scorer = None
23
+ #CiderD_scorer = CiderD(df='corpus')
24
+
25
+ def init_scorer(cached_tokens):
26
+ global CiderD_scorer
27
+ CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
28
+ global Cider_scorer
29
+ Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
30
+ global Bleu_scorer
31
+ Bleu_scorer = Bleu_scorer or Bleu(4)
32
+
33
+ def array_to_str(arr):
34
+ out = ''
35
+ for i in range(len(arr)):
36
+ out += str(arr[i]) + ' '
37
+ if arr[i] == 0:
38
+ break
39
+ return out.strip()
40
+
41
+ def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
42
+ batch_size = len(data_gts)
43
+ gen_result_size = gen_result.shape[0]
44
+ seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
45
+ assert greedy_res.shape[0] == batch_size
46
+
47
+ res = OrderedDict()
48
+ gen_result = gen_result.data.cpu().numpy()
49
+ greedy_res = greedy_res.data.cpu().numpy()
50
+ for i in range(gen_result_size):
51
+ res[i] = [array_to_str(gen_result[i])]
52
+ for i in range(batch_size):
53
+ res[gen_result_size + i] = [array_to_str(greedy_res[i])]
54
+
55
+ gts = OrderedDict()
56
+ for i in range(len(data_gts)):
57
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
58
+
59
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
60
+ res__ = {i: res[i] for i in range(len(res_))}
61
+ gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
62
+ gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
63
+ if opt.cider_reward_weight > 0:
64
+ _, cider_scores = CiderD_scorer.compute_score(gts_, res_)
65
+ print('Cider scores:', _)
66
+ else:
67
+ cider_scores = 0
68
+ if opt.bleu_reward_weight > 0:
69
+ _, bleu_scores = Bleu_scorer.compute_score(gts_, res__)
70
+ bleu_scores = np.array(bleu_scores[3])
71
+ print('Bleu scores:', _[3])
72
+ else:
73
+ bleu_scores = 0
74
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
75
+
76
+ scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
77
+ scores = scores.reshape(gen_result_size)
78
+
79
+ rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
80
+
81
+ return rewards
82
+
83
+ def get_scores(data_gts, gen_result, opt):
84
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
85
+ seq_per_img = batch_size // len(data_gts)
86
+
87
+ res = OrderedDict()
88
+
89
+ gen_result = gen_result.data.cpu().numpy()
90
+ for i in range(batch_size):
91
+ res[i] = [array_to_str(gen_result[i])]
92
+
93
+ gts = OrderedDict()
94
+ for i in range(len(data_gts)):
95
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
96
+
97
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)]
98
+ res__ = {i: res[i] for i in range(batch_size)}
99
+ gts = {i: gts[i // seq_per_img] for i in range(batch_size)}
100
+ if opt.cider_reward_weight > 0:
101
+ _, cider_scores = CiderD_scorer.compute_score(gts, res_)
102
+ print('Cider scores:', _)
103
+ else:
104
+ cider_scores = 0
105
+ if opt.bleu_reward_weight > 0:
106
+ _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
107
+ bleu_scores = np.array(bleu_scores[3])
108
+ print('Bleu scores:', _[3])
109
+ else:
110
+ bleu_scores = 0
111
+
112
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
113
+
114
+ return scores
115
+
116
+ def get_self_cider_scores(data_gts, gen_result, opt):
117
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
118
+ seq_per_img = batch_size // len(data_gts)
119
+
120
+ res = []
121
+
122
+ gen_result = gen_result.data.cpu().numpy()
123
+ for i in range(batch_size):
124
+ res.append(array_to_str(gen_result[i]))
125
+
126
+ scores = []
127
+ for i in range(len(data_gts)):
128
+ tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]])
129
+ def get_div(eigvals):
130
+ eigvals = np.clip(eigvals, 0, None)
131
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
132
+ scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10)))
133
+
134
+ scores = np.array(scores)
135
+
136
+ return scores
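
`get_self_critical_reward` implements the self-critical (SCST) baseline: each sampled caption is scored with CIDEr-D (and optionally BLEU-4) against its references, the score of the greedily decoded caption for the same image is subtracted, and the resulting advantage is copied across every time step. A toy sketch of the call, not a training loop; it assumes the `cider` and `coco-caption` repos are importable and that a cached n-gram file exists ('coco-train-idxs' is a placeholder name):

    import numpy as np
    import torch
    from types import SimpleNamespace
    from captioning.utils.rewards import init_scorer, get_self_critical_reward

    init_scorer('coco-train-idxs')   # placeholder cache of document-frequency stats

    opt = SimpleNamespace(cider_reward_weight=1.0, bleu_reward_weight=0.0)
    data_gts = [np.array([[5, 9, 3, 0]]), np.array([[7, 2, 0, 0]])]   # 2 images, 1 reference each
    gen_result = torch.tensor([[5, 9, 3, 0], [7, 7, 2, 0]])           # sampled captions, seq_per_img = 1
    greedy_res = torch.tensor([[5, 0, 0, 0], [2, 7, 0, 0]])           # greedy baseline captions

    rewards = get_self_critical_reward(greedy_res, data_gts, gen_result, opt)
    print(rewards.shape)   # (2, 4): each row repeats (sampled score - greedy score)
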