diff --git a/app.py b/app.py
index 6b1dd2c32cd02024cf3d7af36b17d7de334e06dd..2feedda9976a7437752e65580531ec4138f33178 100644
--- a/app.py
+++ b/app.py
@@ -1,194 +1,198 @@
-import torch
-import torch.nn as nn
-import numpy as np
-import json
-import captioning.utils.opts as opts
-import captioning.models as models
-import captioning.utils.misc as utils
-import pytorch_lightning as pl
+# import torch
+# import torch.nn as nn
+# import numpy as np
+# import json
+# import captioning.utils.opts as opts
+# import captioning.models as models
+# import captioning.utils.misc as utils
+# import pytorch_lightning as pl
 import gradio as gr
-from diffusers import LDMTextToImagePipeline
-# import PIL.Image
+# from diffusers import LDMTextToImagePipeline
+# # import PIL.Image
 import random
-import os
+# import os
-# Checkpoint class
-class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
-    def on_keyboard_interrupt(self, trainer, pl_module):
-        # Save model when keyboard interrupt
-        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
-        self._save_model(filepath)
+# # Checkpoint class
+# class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+#     def on_keyboard_interrupt(self, trainer, pl_module):
+#         # Save model when keyboard interrupt
+#         filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
+#         self._save_model(filepath)
-device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
-reward = 'clips_grammar'
+# device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
+# reward = 'clips_grammar'
-cfg = f'./configs/phase2/clipRN50_{reward}.yml'
+# cfg = f'./configs/phase2/clipRN50_{reward}.yml'
-print("Loading cfg from", cfg)
+# print("Loading cfg from", cfg)
-opt = opts.parse_opt(parse=False, cfg=cfg)
+# opt = opts.parse_opt(parse=False, cfg=cfg)
-import gdown
+# import gdown
-url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
-gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
+# url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
+# gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
-url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
-gdown.download(url, quiet=True, use_cookies=False, output="data/")
+# url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
+# gdown.download(url, quiet=True, use_cookies=False, output="data/")
-dict_json = json.load(open('./data/cocotalk.json'))
-print(dict_json.keys())
+# dict_json = json.load(open('./data/cocotalk.json'))
+# print(dict_json.keys())
-ix_to_word = dict_json['ix_to_word']
-vocab_size = len(ix_to_word)
-print('vocab size:', vocab_size)
+# ix_to_word = dict_json['ix_to_word']
+# vocab_size = len(ix_to_word)
+# print('vocab size:', vocab_size)
-seq_length = 1
+# seq_length = 1
-opt.vocab_size = vocab_size
-opt.seq_length = seq_length
+# opt.vocab_size = vocab_size
+# opt.seq_length = seq_length
-opt.batch_size = 1
-opt.vocab = ix_to_word
+# opt.batch_size = 1
+# opt.vocab = ix_to_word
-model = models.setup(opt)
-del opt.vocab
+# model = models.setup(opt)
+# del opt.vocab
-ckpt_path = opt.checkpoint_path + '-last.ckpt'
+# ckpt_path = opt.checkpoint_path + '-last.ckpt'
-print("Loading checkpoint from", ckpt_path)
-raw_state_dict = torch.load(
-    ckpt_path,
-    map_location=device)
+# print("Loading checkpoint from", ckpt_path)
+# raw_state_dict = torch.load(
+#     ckpt_path,
+#     map_location=device)
-strict = True
+# strict = True
-state_dict = raw_state_dict['state_dict']
+# state_dict = raw_state_dict['state_dict']
-if '_vocab' in state_dict:
-    model.vocab = utils.deserialize(state_dict['_vocab'])
-    del state_dict['_vocab']
-elif strict:
-    raise KeyError
-if '_opt' in state_dict:
-    saved_model_opt = utils.deserialize(state_dict['_opt'])
-    del state_dict['_opt']
-    # Make sure the saved opt is compatible with the curren topt
-    need_be_same = ["caption_model",
-                    "rnn_type", "rnn_size", "num_layers"]
-    for checkme in need_be_same:
-        if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
-            getattr(opt, checkme) in ['updown', 'topdown']:
-            continue
-        assert getattr(saved_model_opt, checkme) == getattr(
-            opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
-elif strict:
-    raise KeyError
-res = model.load_state_dict(state_dict, strict)
-print(res)
+# if '_vocab' in state_dict:
+#     model.vocab = utils.deserialize(state_dict['_vocab'])
+#     del state_dict['_vocab']
+# elif strict:
+#     raise KeyError
+# if '_opt' in state_dict:
+#     saved_model_opt = utils.deserialize(state_dict['_opt'])
+#     del state_dict['_opt']
+#     # Make sure the saved opt is compatible with the curren topt
+#     need_be_same = ["caption_model",
+#                     "rnn_type", "rnn_size", "num_layers"]
+#     for checkme in need_be_same:
+#         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
+#             getattr(opt, checkme) in ['updown', 'topdown']:
+#             continue
+#         assert getattr(saved_model_opt, checkme) == getattr(
+#             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
+# elif strict:
+#     raise KeyError
+# res = model.load_state_dict(state_dict, strict)
+# print(res)
-model = model.to(device)
-model.eval();
+# model = model.to(device)
+# model.eval();
-import clip
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from PIL import Image
-from timm.models.vision_transformer import resize_pos_embed
+# import clip
+# from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+# from PIL import Image
+# from timm.models.vision_transformer import resize_pos_embed
-clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
+# clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
-preprocess = Compose([
-    Resize((448, 448), interpolation=Image.BICUBIC),
-    CenterCrop((448, 448)),
-    ToTensor()
-])
+# preprocess = Compose([
+#     Resize((448, 448), interpolation=Image.BICUBIC),
+#     CenterCrop((448, 448)),
+#     ToTensor()
+# ])
-image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
-image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
+# image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
+# image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
-num_patches = 196 #600 * 1000 // 32 // 32
-pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
-pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
-clip_model.visual.attnpool.positional_embedding = pos_embed
+# num_patches = 196 #600 * 1000 // 32 // 32
+# pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
+# pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
+# clip_model.visual.attnpool.positional_embedding = pos_embed
-# End below
-print('Loading the model: CompVis/ldm-text2im-large-256')
-ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+# # End below
+# print('Loading the model: CompVis/ldm-text2im-large-256')
+# ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
-    print('RUN: generate_image_from_text')
-    torch.cuda.empty_cache()
-    generator = torch.manual_seed(seed)
-    images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
-    return images[0]
-
-def generate_text_from_image(img):
-    print('RUN: generate_text_from_image')
-    with torch.no_grad():
-        image = preprocess(img)
-        image = torch.tensor(np.stack([image])).to(device)
-        image -= image_mean
-        image /= image_std
+# def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
+#     print('RUN: generate_image_from_text')
+#     torch.cuda.empty_cache()
+#     generator = torch.manual_seed(seed)
+#     images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
+#     return images[0]
+
+# def generate_text_from_image(img):
+#     print('RUN: generate_text_from_image')
+#     with torch.no_grad():
+#         image = preprocess(img)
+#         image = torch.tensor(np.stack([image])).to(device)
+#         image -= image_mean
+#         image /= image_std
-        tmp_att, tmp_fc = clip_model.encode_image(image)
-        tmp_att = tmp_att[0].permute(1, 2, 0)
-        tmp_fc = tmp_fc[0]
+#         tmp_att, tmp_fc = clip_model.encode_image(image)
+#         tmp_att = tmp_att[0].permute(1, 2, 0)
+#         tmp_fc = tmp_fc[0]
-    att_feat = tmp_att
-    fc_feat = tmp_fc
+#     att_feat = tmp_att
+#     fc_feat = tmp_fc
-    # Inference configurations
-    eval_kwargs = {}
-    eval_kwargs.update(vars(opt))
+#     # Inference configurations
+#     eval_kwargs = {}
+#     eval_kwargs.update(vars(opt))
-    verbose = eval_kwargs.get('verbose', True)
-    verbose_beam = eval_kwargs.get('verbose_beam', 0)
-    verbose_loss = eval_kwargs.get('verbose_loss', 1)
+#     verbose = eval_kwargs.get('verbose', True)
+#     verbose_beam = eval_kwargs.get('verbose_beam', 0)
+#     verbose_loss = eval_kwargs.get('verbose_loss', 1)
-    # dataset = eval_kwargs.get('dataset', 'coco')
-    beam_size = eval_kwargs.get('beam_size', 1)
-    sample_n = eval_kwargs.get('sample_n', 1)
-    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
+#     # dataset = eval_kwargs.get('dataset', 'coco')
+#     beam_size = eval_kwargs.get('beam_size', 1)
+#     sample_n = eval_kwargs.get('sample_n', 1)
+#     remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
-    with torch.no_grad():
-        fc_feats = torch.zeros((1,0)).to(device)
-        att_feats = att_feat.view(1, 196, 2048).float().to(device)
-        att_masks = None
+#     with torch.no_grad():
+#         fc_feats = torch.zeros((1,0)).to(device)
+#         att_feats = att_feat.view(1, 196, 2048).float().to(device)
+#         att_masks = None
-        # forward the model to also get generated samples for each image
-        # Only leave one feature for each image, in case duplicate sample
-        tmp_eval_kwargs = eval_kwargs.copy()
-        tmp_eval_kwargs.update({'sample_n': 1})
-        seq, seq_logprobs = model(
-            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
-        seq = seq.data
+#         # forward the model to also get generated samples for each image
+#         # Only leave one feature for each image, in case duplicate sample
+#         tmp_eval_kwargs = eval_kwargs.copy()
+#         tmp_eval_kwargs.update({'sample_n': 1})
+#         seq, seq_logprobs = model(
+#             fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
+#         seq = seq.data
-    sents = utils.decode_sequence(model.vocab, seq)
+#     sents = utils.decode_sequence(model.vocab, seq)
-    return sents[0]
+#     return sents[0]
-def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
-    print('RUN: generate_drawing_from_image')
-    caption = generate_text_from_image(img)
-    gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
-    return gen_image
+# def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
+#     print('RUN: generate_drawing_from_image')
+#     caption = generate_text_from_image(img)
+#     gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
+#     return gen_image
 random_seed = random.randint(0, 2147483647)
+def test_fn(**kwargs):
+    return None
+
 gr.Interface(
-    generate_drawing_from_image,
-    inputs=[
-        gr.Image(type="pil"),
-        gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
-        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1),
-        gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=6.0, step=0.1),
-    ],
-    outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
-    css="#output_image{width: 256px}",
+#    generate_drawing_from_image,
+    test_fn,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
+        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1),
+        gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=6.0, step=0.1),
+    ],
+    outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
+    css="#output_image{width: 256px}",
 ).launch()
diff --git a/captioning/__init__.py b/captioning/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/captioning/data/__init__.py b/captioning/data/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/captioning/data/dataloader.py b/captioning/data/dataloader.py
deleted file mode 100644
index 7f2ed0304bd94db21bbc9fbdc6857beccb8bb621..0000000000000000000000000000000000000000
--- a/captioning/data/dataloader.py
+++ /dev/null
@@ -1,425 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import json
-import h5py
-from lmdbdict import lmdbdict
-from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
-import os
-import numpy as np
-import numpy.random as npr
-import random
-from functools import partial
-
-import torch
-import torch.utils.data as data
-
-import multiprocessing
-import six
-
-class HybridLoader:
-    """
-    If db_path is a director, then use normal file loading
-    If lmdb, then load from lmdb
-    The loading method depend on extention.
-
-    in_memory: if in_memory is True, we save all the features in memory
-    For individual np(y|z)s, we don't need to do that because the system will do this for us.
-    Should be useful for lmdb or h5.
- (Copied this idea from vilbert) - """ - def __init__(self, db_path, ext, in_memory=False): - self.db_path = db_path - self.ext = ext - if self.ext == '.npy': - self.loader = lambda x: np.load(six.BytesIO(x)) - else: - def load_npz(x): - x = np.load(six.BytesIO(x)) - return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly. - self.loader = load_npz - if db_path.endswith('.lmdb'): - self.db_type = 'lmdb' - self.lmdb = lmdbdict(db_path, unsafe=True) - self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - self.lmdb._value_loads = LOADS_FUNC['identity'] - elif db_path.endswith('.pth'): # Assume a key,value dictionary - self.db_type = 'pth' - self.feat_file = torch.load(db_path) - self.loader = lambda x: x - print('HybridLoader: ext is ignored') - elif db_path.endswith('h5'): - self.db_type = 'h5' - self.loader = lambda x: np.array(x).astype('float32') - else: - self.db_type = 'dir' - - self.in_memory = in_memory - if self.in_memory: - self.features = {} - - def get(self, key): - - if self.in_memory and key in self.features: - # We save f_input because we want to save the - # compressed bytes to save memory - f_input = self.features[key] - elif self.db_type == 'lmdb': - f_input = self.lmdb[key] - elif self.db_type == 'pth': - f_input = self.feat_file[key] - elif self.db_type == 'h5': - f_input = h5py.File(self.db_path, 'r')[key] - else: - f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() - - if self.in_memory and key not in self.features: - self.features[key] = f_input - - # load image - feat = self.loader(f_input) - - return feat - -class Dataset(data.Dataset): - - def get_vocab_size(self): - return self.vocab_size - - def get_vocab(self): - return self.ix_to_word - - def get_seq_length(self): - return self.seq_length - - def __init__(self, opt): - self.opt = opt - self.seq_per_img = opt.seq_per_img - - # feature related options - self.use_fc = getattr(opt, 'use_fc', True) - self.use_att = getattr(opt, 'use_att', True) - self.use_box = getattr(opt, 'use_box', 0) - self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) - self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) - - # load the json file which contains additional information about the dataset - print('DataLoader loading json file: ', opt.input_json) - self.info = json.load(open(self.opt.input_json)) - if 'ix_to_word' in self.info: - self.ix_to_word = self.info['ix_to_word'] - self.vocab_size = len(self.ix_to_word) - print('vocab size is ', self.vocab_size) - - # open the hdf5 file - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) - """ - Setting input_label_h5 to none is used when only doing generation. - For example, when you need to test on coco test set. 
- """ - if self.opt.input_label_h5 != 'none': - self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') - # load in the sequence data - seq_size = self.h5_label_file['labels'].shape - self.label = self.h5_label_file['labels'][:] - self.seq_length = seq_size[1] - print('max sequence length in data is', self.seq_length) - # load the pointers in full to RAM (should be small enough) - self.label_start_ix = self.h5_label_file['label_start_ix'][:] - self.label_end_ix = self.h5_label_file['label_end_ix'][:] - else: - self.seq_length = 1 - - self.data_in_memory = getattr(opt, 'data_in_memory', False) - self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) - self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) - self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) - - self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] - print('read %d image features' %(self.num_images)) - - # separate out indexes for each of the provided splits - self.split_ix = {'train': [], 'val': [], 'test': []} - for ix in range(len(self.info['images'])): - img = self.info['images'][ix] - if not 'split' in img: - self.split_ix['train'].append(ix) - self.split_ix['val'].append(ix) - self.split_ix['test'].append(ix) - elif img['split'] == 'train': - self.split_ix['train'].append(ix) - elif img['split'] == 'val': - self.split_ix['val'].append(ix) - elif img['split'] == 'test': - self.split_ix['test'].append(ix) - elif opt.train_only == 0: # restval - self.split_ix['train'].append(ix) - - print('assigned %d images to split train' %len(self.split_ix['train'])) - print('assigned %d images to split val' %len(self.split_ix['val'])) - print('assigned %d images to split test' %len(self.split_ix['test'])) - - def get_captions(self, ix, seq_per_img): - # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 - ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' - - if ncap < seq_per_img: - # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') - for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.label[ixl, :self.seq_length] - else: - ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] - - return seq - - def collate_func(self, batch, split): - seq_per_img = self.seq_per_img - - fc_batch = [] - att_batch = [] - label_batch = [] - - wrapped = False - - infos = [] - gts = [] - - for sample in batch: - # fetch image - tmp_fc, tmp_att, tmp_seq, \ - ix, it_pos_now, tmp_wrapped = sample - if tmp_wrapped: - wrapped = True - - fc_batch.append(tmp_fc) - att_batch.append(tmp_att) - - tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') - if hasattr(self, 'h5_label_file'): - # if there is ground truth - tmp_label[:, 1 : self.seq_length + 1] = tmp_seq - label_batch.append(tmp_label) - - # Used for reward evaluation - if hasattr(self, 'h5_label_file'): - # if there is ground truth - gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) - else: - gts.append([]) - - # record associated info as well - info_dict = {} - info_dict['ix'] = ix - info_dict['id'] = self.info['images'][ix]['id'] - info_dict['file_path'] = self.info['images'][ix].get('file_path', '') - infos.append(info_dict) - - # #sort by att_feat length - # fc_batch, att_batch, label_batch, gts, infos = \ - # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) - data = {} - data['fc_feats'] = np.stack(fc_batch) - # merge att_feats - max_att_len = max([_.shape[0] for _ in att_batch]) - data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') - for i in range(len(att_batch)): - data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] - data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') - for i in range(len(att_batch)): - data['att_masks'][i, :att_batch[i].shape[0]] = 1 - # set att_masks to None if attention features have same length - if data['att_masks'].sum() == data['att_masks'].size: - data['att_masks'] = None - - data['labels'] = np.vstack(label_batch) - # generate mask - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) - mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') - for ix, row in enumerate(mask_batch): - row[:nonzeros[ix]] = 1 - data['masks'] = mask_batch - data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) - data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) - - data['gts'] = gts # all ground truth captions of each images - data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample - 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} - data['infos'] = infos - - data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor - - return data - - def __getitem__(self, index): - """This function returns a tuple that is further passed to collate_fn - """ - ix, it_pos_now, wrapped = index #self.split_ix[index] - if self.use_att: - att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) - # Reshape to K 
x C - att_feat = att_feat.reshape(-1, att_feat.shape[-1]) - if self.norm_att_feat: - att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) - if self.use_box: - box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) - # devided by image width and height - x1,y1,x2,y2 = np.hsplit(box_feat, 4) - h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] - box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? - if self.norm_box_feat: - box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) - att_feat = np.hstack([att_feat, box_feat]) - # sort the features by the size of boxes - att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) - else: - att_feat = np.zeros((0,0), dtype='float32') - if self.use_fc: - try: - fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) - except: - # Use average of attention when there is no fc provided (For bottomup feature) - fc_feat = att_feat.mean(0) - else: - fc_feat = np.zeros((0), dtype='float32') - if hasattr(self, 'h5_label_file'): - seq = self.get_captions(ix, self.seq_per_img) - else: - seq = None - return (fc_feat, - att_feat, seq, - ix, it_pos_now, wrapped) - - def __len__(self): - return len(self.info['images']) - -class DataLoader: - def __init__(self, opt): - self.opt = opt - self.batch_size = self.opt.batch_size - self.dataset = Dataset(opt) - - # Initialize loaders and iters - self.loaders, self.iters = {}, {} - for split in ['train', 'val', 'test']: - if split == 'train': - sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True) - else: - sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False) - self.loaders[split] = data.DataLoader(dataset=self.dataset, - batch_size=self.batch_size, - sampler=sampler, - pin_memory=True, - num_workers=4, # 4 is usually enough - collate_fn=partial(self.dataset.collate_func, split=split), - drop_last=False) - self.iters[split] = iter(self.loaders[split]) - - def get_batch(self, split): - try: - data = next(self.iters[split]) - except StopIteration: - self.iters[split] = iter(self.loaders[split]) - data = next(self.iters[split]) - return data - - def reset_iterator(self, split): - self.loaders[split].sampler._reset_iter() - self.iters[split] = iter(self.loaders[split]) - - def get_vocab_size(self): - return self.dataset.get_vocab_size() - - @property - def vocab_size(self): - return self.get_vocab_size() - - def get_vocab(self): - return self.dataset.get_vocab() - - def get_seq_length(self): - return self.dataset.get_seq_length() - - @property - def seq_length(self): - return self.get_seq_length() - - def state_dict(self): - def get_prefetch_num(split): - if self.loaders[split].num_workers > 0: - return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size - else: - return 0 - return {split: loader.sampler.state_dict(get_prefetch_num(split)) \ - for split, loader in self.loaders.items()} - - def load_state_dict(self, state_dict=None): - if state_dict is None: - return - for split in self.loaders.keys(): - self.loaders[split].sampler.load_state_dict(state_dict[split]) - - -class MySampler(data.sampler.Sampler): - def __init__(self, index_list, shuffle, wrap): - self.index_list = index_list - self.shuffle = shuffle - self.wrap = wrap - # if wrap, there will be not stop iteration called - # wrap True used during training, and wrap False used during test. 
- self._reset_iter() - - def __iter__(self): - return self - - def __next__(self): - wrapped = False - if self.iter_counter == len(self._index_list): - self._reset_iter() - if self.wrap: - wrapped = True - else: - raise StopIteration() - if len(self._index_list) == 0: # overflow when 0 samples - return None - elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped) - self.iter_counter += 1 - return elem - - def next(self): - return self.__next__() - - def _reset_iter(self): - if self.shuffle: - rand_perm = npr.permutation(len(self.index_list)) - self._index_list = [self.index_list[_] for _ in rand_perm] - else: - self._index_list = self.index_list - - self.iter_counter = 0 - - def __len__(self): - return len(self.index_list) - - def load_state_dict(self, state_dict=None): - if state_dict is None: - return - self._index_list = state_dict['index_list'] - self.iter_counter = state_dict['iter_counter'] - - def state_dict(self, prefetched_num=None): - prefetched_num = prefetched_num or 0 - return { - 'index_list': self._index_list, - 'iter_counter': self.iter_counter - prefetched_num - } - - \ No newline at end of file diff --git a/captioning/data/pth_loader.py b/captioning/data/pth_loader.py deleted file mode 100644 index 28023699735470daa7e2ab4752a31ea8282c04c5..0000000000000000000000000000000000000000 --- a/captioning/data/pth_loader.py +++ /dev/null @@ -1,334 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import h5py -from lmdbdict import lmdbdict -from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC -import os -import numpy as np -import numpy.random as npr -import random - -import torch -import torch.utils.data as data - -import multiprocessing -import six - -verbose = True -# import torch -# if torch.cuda.current_device() in [0, -1]: -if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - verbose = False - -class HybridLoader: - """ - If db_path is a director, then use normal file loading - If lmdb, then load from lmdb - The loading method depend on extention. - - in_memory: if in_memory is True, we save all the features in memory - For individual np(y|z)s, we don't need to do that because the system will do this for us. - Should be useful for lmdb or h5. 
- (Copied this idea from vilbert) - """ - def __init__(self, db_path, ext, in_memory=False): - self.db_path = db_path - self.ext = ext - if self.ext == '.npy': - self.loader = lambda x: np.load(six.BytesIO(x)) - else: - self.loader = lambda x: np.load(six.BytesIO(x))['feat'] - if db_path.endswith('.lmdb'): - self.db_type = 'lmdb' - self.lmdb = lmdbdict(db_path, unsafe=True) - self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - self.lmdb._value_loads = LOADS_FUNC['identity'] - elif db_path.endswith('.pth'): # Assume a key,value dictionary - self.db_type = 'pth' - self.feat_file = torch.load(db_path) - self.loader = lambda x: x - print('HybridLoader: ext is ignored') - elif db_path.endswith('h5'): - self.db_type = 'h5' - self.loader = lambda x: np.array(x).astype('float32') - else: - self.db_type = 'dir' - - self.in_memory = in_memory - if self.in_memory: - self.features = {} - - def get(self, key): - - if self.in_memory and key in self.features: - # We save f_input because we want to save the - # compressed bytes to save memory - f_input = self.features[key] - elif self.db_type == 'lmdb': - f_input = self.lmdb[key] - elif self.db_type == 'pth': - f_input = self.feat_file[key] - elif self.db_type == 'h5': - f_input = h5py.File(self.db_path, 'r')[key] - else: - f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() - - if self.in_memory and key not in self.features: - self.features[key] = f_input - - # load image - feat = self.loader(f_input) - - return feat - -class CaptionDataset(data.Dataset): - - def get_vocab_size(self): - return self.vocab_size - - def get_vocab(self): - return self.ix_to_word - - def get_seq_length(self): - return self.seq_length - - def __init__(self, opt): - self.opt = opt - self.seq_per_img = opt.seq_per_img - - # feature related options - self.use_fc = getattr(opt, 'use_fc', True) - self.use_att = getattr(opt, 'use_att', True) - self.use_box = getattr(opt, 'use_box', 0) - self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) - self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) - - # load the json file which contains additional information about the dataset - if verbose: - print('DataLoader loading json file: ', opt.input_json) - self.info = json.load(open(self.opt.input_json)) - if 'ix_to_word' in self.info: - self.ix_to_word = self.info['ix_to_word'] - self.vocab_size = len(self.ix_to_word) - if verbose: - print('vocab size is ', self.vocab_size) - - # open the hdf5 file - if verbose: - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) - """ - Setting input_label_h5 to none is used when only doing generation. - For example, when you need to test on coco test set. 
- """ - if self.opt.input_label_h5 != 'none': - self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') - # load in the sequence data - seq_size = self.h5_label_file['labels'].shape - self.label = self.h5_label_file['labels'][:] - self.seq_length = seq_size[1] - if verbose: - print('max sequence length in data is', self.seq_length) - # load the pointers in full to RAM (should be small enough) - self.label_start_ix = self.h5_label_file['label_start_ix'][:] - self.label_end_ix = self.h5_label_file['label_end_ix'][:] - else: - self.seq_length = 1 - - self.data_in_memory = getattr(opt, 'data_in_memory', False) - self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) - self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) - self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) - - self.use_clipscore = getattr(opt, 'use_clipscore', False) - # if self.use_clipscore: - self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) - - - self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] - if verbose: - print('read %d image features' %(self.num_images)) - - # separate out indexes for each of the provided splits - self.split_ix = {'train': [], 'val': [], 'test': []} - for ix in range(len(self.info['images'])): - img = self.info['images'][ix] - if not 'split' in img: - self.split_ix['train'].append(ix) - self.split_ix['val'].append(ix) - self.split_ix['test'].append(ix) - elif img['split'] == 'train': - self.split_ix['train'].append(ix) - elif img['split'] == 'val': - self.split_ix['val'].append(ix) - elif img['split'] == 'test': - self.split_ix['test'].append(ix) - elif opt.train_only == 0: # restval - self.split_ix['train'].append(ix) - - if verbose: - print('assigned %d images to split train' %len(self.split_ix['train'])) - print('assigned %d images to split val' %len(self.split_ix['val'])) - print('assigned %d images to split test' %len(self.split_ix['test'])) - - def get_captions(self, ix, seq_per_img): - # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 - ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' - - if ncap < seq_per_img: - # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') - for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.label[ixl, :self.seq_length] - else: - ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] - - return seq - - def collate_func(self, batch): - seq_per_img = self.seq_per_img - - fc_batch = [] - att_batch = [] - label_batch = [] - - clip_vis_feat_batch = [] - - wrapped = False - - infos = [] - gts = [] - - for sample in batch: - # fetch image - # if self.use_clipscore: - tmp_fc, tmp_att, tmp_seq, \ - ix, tmp_clip_vis_feat = sample - - clip_vis_feat_batch.append(tmp_clip_vis_feat) - # else: - # tmp_fc, tmp_att, tmp_seq, \ - # ix = sample - - fc_batch.append(tmp_fc) - att_batch.append(tmp_att) - - tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') - if hasattr(self, 'h5_label_file'): - # if there is ground truth - tmp_label[:, 1 : self.seq_length + 1] = tmp_seq - label_batch.append(tmp_label) - - # Used for reward evaluation - if hasattr(self, 'h5_label_file'): - # if there is ground truth - gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) - else: - gts.append([]) - - # record associated info as well - info_dict = {} - info_dict['ix'] = ix - info_dict['id'] = self.info['images'][ix]['id'] - info_dict['file_path'] = self.info['images'][ix].get('file_path', '') - infos.append(info_dict) - - # #sort by att_feat length - # fc_batch, att_batch, label_batch, gts, infos = \ - # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - if self.use_clipscore: - fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True)) - else: - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) - data = {} - data['fc_feats'] = np.stack(fc_batch) - # merge att_feats - max_att_len = max([_.shape[0] for _ in att_batch]) - data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') - for i in range(len(att_batch)): - data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] - data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') - for i in range(len(att_batch)): - data['att_masks'][i, :att_batch[i].shape[0]] = 1 - # set att_masks to None if attention features have same length - if data['att_masks'].sum() == data['att_masks'].size: - data['att_masks'] = None - - # if self.use_clipscore: - data['clip_vis_feats'] = np.stack(clip_vis_feat_batch) - - data['labels'] = np.vstack(label_batch) - # generate mask - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) - mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') - for ix, row in enumerate(mask_batch): - row[:nonzeros[ix]] = 1 - data['masks'] = mask_batch - data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) - data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) - - data['gts'] = gts # all ground truth captions of each images - data['infos'] = infos - - data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor - - 
return data - - def __getitem__(self, ix): - """This function returns a tuple that is further passed to collate_fn - """ - if self.use_att: - att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) - # Reshape to K x C - att_feat = att_feat.reshape(-1, att_feat.shape[-1]) - if self.norm_att_feat: - att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) - if self.use_box: - box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) - # devided by image width and height - x1,y1,x2,y2 = np.hsplit(box_feat, 4) - h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] - box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? - if self.norm_box_feat: - box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) - att_feat = np.hstack([att_feat, box_feat]) - # sort the features by the size of boxes - att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) - else: - att_feat = np.zeros((0,0), dtype='float32') - if self.use_fc: - try: - fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) - except: - # Use average of attention when there is no fc provided (For bottomup feature) - fc_feat = att_feat.mean(0) - else: - fc_feat = np.zeros((0), dtype='float32') - if hasattr(self, 'h5_label_file'): - seq = self.get_captions(ix, self.seq_per_img) - else: - seq = None - - # if self.use_clipscore: - clip_vis_feat = self.clipscore_loader.get( - str(self.info['images'][ix]['id'])) - - return (fc_feat, - att_feat, seq, - ix, clip_vis_feat) - - # return (fc_feat, - # att_feat, seq, - # ix) - - def __len__(self): - return len(self.info['images']) diff --git a/captioning/data/pth_loader_FineCapEval.py b/captioning/data/pth_loader_FineCapEval.py deleted file mode 100644 index 388301edd763d54d95675ca2ed6eb502f77e1eb1..0000000000000000000000000000000000000000 --- a/captioning/data/pth_loader_FineCapEval.py +++ /dev/null @@ -1,334 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import h5py -from lmdbdict import lmdbdict -from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC -import os -import numpy as np -import numpy.random as npr -import random - -import torch -import torch.utils.data as data - -import multiprocessing -import six - -verbose = True -# import torch -# if torch.cuda.current_device() in [0, -1]: -if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - verbose = False - -class HybridLoader: - """ - If db_path is a director, then use normal file loading - If lmdb, then load from lmdb - The loading method depend on extention. - - in_memory: if in_memory is True, we save all the features in memory - For individual np(y|z)s, we don't need to do that because the system will do this for us. - Should be useful for lmdb or h5. 
- (Copied this idea from vilbert) - """ - def __init__(self, db_path, ext, in_memory=False): - self.db_path = db_path - self.ext = ext - if self.ext == '.npy': - self.loader = lambda x: np.load(six.BytesIO(x)) - else: - self.loader = lambda x: np.load(six.BytesIO(x))['feat'] - if db_path.endswith('.lmdb'): - self.db_type = 'lmdb' - self.lmdb = lmdbdict(db_path, unsafe=True) - self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - self.lmdb._value_loads = LOADS_FUNC['identity'] - elif db_path.endswith('.pth'): # Assume a key,value dictionary - self.db_type = 'pth' - self.feat_file = torch.load(db_path) - self.loader = lambda x: x - print('HybridLoader: ext is ignored') - elif db_path.endswith('h5'): - self.db_type = 'h5' - self.loader = lambda x: np.array(x).astype('float32') - else: - self.db_type = 'dir' - - self.in_memory = in_memory - if self.in_memory: - self.features = {} - - def get(self, key): - - if self.in_memory and key in self.features: - # We save f_input because we want to save the - # compressed bytes to save memory - f_input = self.features[key] - elif self.db_type == 'lmdb': - f_input = self.lmdb[key] - elif self.db_type == 'pth': - f_input = self.feat_file[key] - elif self.db_type == 'h5': - f_input = h5py.File(self.db_path, 'r')[key] - else: - f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() - - if self.in_memory and key not in self.features: - self.features[key] = f_input - - # load image - feat = self.loader(f_input) - - return feat - -class CaptionDataset(data.Dataset): - - def get_vocab_size(self): - return self.vocab_size - - def get_vocab(self): - return self.ix_to_word - - def get_seq_length(self): - return self.seq_length - - def __init__(self, opt): - self.opt = opt - self.seq_per_img = opt.seq_per_img - - # feature related options - self.use_fc = getattr(opt, 'use_fc', True) - self.use_att = getattr(opt, 'use_att', True) - self.use_box = getattr(opt, 'use_box', 0) - self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) - self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) - - # load the json file which contains additional information about the dataset - if verbose: - print('DataLoader loading json file: ', opt.input_json) - self.info = json.load(open(self.opt.input_json)) - if 'ix_to_word' in self.info: - self.ix_to_word = self.info['ix_to_word'] - self.vocab_size = len(self.ix_to_word) - if verbose: - print('vocab size is ', self.vocab_size) - - # open the hdf5 file - if verbose: - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) - """ - Setting input_label_h5 to none is used when only doing generation. - For example, when you need to test on coco test set. 
- """ - if self.opt.input_label_h5 != 'none': - self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') - # load in the sequence data - seq_size = self.h5_label_file['labels'].shape - self.label = self.h5_label_file['labels'][:] - self.seq_length = seq_size[1] - if verbose: - print('max sequence length in data is', self.seq_length) - # load the pointers in full to RAM (should be small enough) - self.label_start_ix = self.h5_label_file['label_start_ix'][:] - self.label_end_ix = self.h5_label_file['label_end_ix'][:] - else: - self.seq_length = 1 - - self.data_in_memory = getattr(opt, 'data_in_memory', False) - self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) - self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) - self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) - - self.use_clipscore = getattr(opt, 'use_clipscore', False) - if self.use_clipscore: - self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) - - - self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] - if verbose: - print('read %d image features' %(self.num_images)) - - # separate out indexes for each of the provided splits - self.split_ix = {'train': [], 'val': [], 'test': []} - for ix in range(len(self.info['images'])): - img = self.info['images'][ix] - if not 'split' in img: - self.split_ix['train'].append(ix) - self.split_ix['val'].append(ix) - self.split_ix['test'].append(ix) - elif img['split'] == 'train': - self.split_ix['train'].append(ix) - elif img['split'] == 'val': - self.split_ix['val'].append(ix) - elif img['split'] == 'test': - self.split_ix['test'].append(ix) - elif opt.train_only == 0: # restval - self.split_ix['train'].append(ix) - - if verbose: - print('assigned %d images to split train' %len(self.split_ix['train'])) - print('assigned %d images to split val' %len(self.split_ix['val'])) - print('assigned %d images to split test' %len(self.split_ix['test'])) - - def get_captions(self, ix, seq_per_img): - # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 - ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' - - if ncap < seq_per_img: - # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') - for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.label[ixl, :self.seq_length] - else: - ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] - - return seq - - def collate_func(self, batch): - seq_per_img = self.seq_per_img - - fc_batch = [] - att_batch = [] - label_batch = [] - - clip_vis_feat_batch = [] - - wrapped = False - - infos = [] - gts = [] - - for sample in batch: - # fetch image - if self.use_clipscore: - tmp_fc, tmp_att, tmp_seq, \ - ix, tmp_clip_vis_feat = sample - - clip_vis_feat_batch.append(tmp_clip_vis_feat) - else: - tmp_fc, tmp_att, tmp_seq, \ - ix = sample - - fc_batch.append(tmp_fc) - att_batch.append(tmp_att) - - tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') - if hasattr(self, 'h5_label_file'): - # if there is ground truth - tmp_label[:, 1 : self.seq_length + 1] = tmp_seq - label_batch.append(tmp_label) - - # Used for reward evaluation - if hasattr(self, 'h5_label_file'): - # if there is ground truth - gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) - else: - gts.append([]) - - # record associated info as well - info_dict = {} - info_dict['ix'] = ix - info_dict['id'] = self.info['images'][ix]['id'] - info_dict['file_path'] = self.info['images'][ix].get('file_path', '') - infos.append(info_dict) - - # #sort by att_feat length - # fc_batch, att_batch, label_batch, gts, infos = \ - # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - if self.use_clipscore: - fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True)) - else: - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) - data = {} - data['fc_feats'] = np.stack(fc_batch) - # merge att_feats - max_att_len = max([_.shape[0] for _ in att_batch]) - data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') - for i in range(len(att_batch)): - data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] - data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') - for i in range(len(att_batch)): - data['att_masks'][i, :att_batch[i].shape[0]] = 1 - # set att_masks to None if attention features have same length - if data['att_masks'].sum() == data['att_masks'].size: - data['att_masks'] = None - - if self.use_clipscore: - data['clip_vis_feats'] = np.stack(clip_vis_feat_batch) - - data['labels'] = np.vstack(label_batch) - # generate mask - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) - mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') - for ix, row in enumerate(mask_batch): - row[:nonzeros[ix]] = 1 - data['masks'] = mask_batch - data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) - data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) - - data['gts'] = gts # all ground truth captions of each images - data['infos'] = infos - - data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor - - return 
data - - def __getitem__(self, ix): - """This function returns a tuple that is further passed to collate_fn - """ - if self.use_att: - att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) - # Reshape to K x C - att_feat = att_feat.reshape(-1, att_feat.shape[-1]) - if self.norm_att_feat: - att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) - if self.use_box: - box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) - # devided by image width and height - x1,y1,x2,y2 = np.hsplit(box_feat, 4) - h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] - box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? - if self.norm_box_feat: - box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) - att_feat = np.hstack([att_feat, box_feat]) - # sort the features by the size of boxes - att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) - else: - att_feat = np.zeros((0,0), dtype='float32') - if self.use_fc: - try: - fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) - except: - # Use average of attention when there is no fc provided (For bottomup feature) - fc_feat = att_feat.mean(0) - else: - fc_feat = np.zeros((0), dtype='float32') - if hasattr(self, 'h5_label_file'): - seq = self.get_captions(ix, self.seq_per_img) - else: - seq = None - - if self.use_clipscore: - clip_vis_feat = self.clipscore_loader.get( - str(self.info['images'][ix]['id'])) - - return (fc_feat, - att_feat, seq, - ix, clip_vis_feat) - - return (fc_feat, - att_feat, seq, - ix) - - def __len__(self): - return len(self.info['images']) diff --git a/captioning/models/AoAModel.py b/captioning/models/AoAModel.py deleted file mode 100644 index 7925549fc7d134a98f8e12b6b4b741b03ab63c78..0000000000000000000000000000000000000000 --- a/captioning/models/AoAModel.py +++ /dev/null @@ -1,228 +0,0 @@ -# Implementation for paper 'Attention on Attention for Image Captioning' -# https://arxiv.org/abs/1908.06954 - -# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/ - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .AttModel import pack_wrapper, AttModel, Attention -from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward - -class MultiHeadedDotAttention(nn.Module): - def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3): - super(MultiHeadedDotAttention, self).__init__() - assert d_model * scale % h == 0 - # We assume d_v always equals d_k - self.d_k = d_model * scale // h - self.h = h - - # Do we need to do linear projections on K and V? - self.project_k_v = project_k_v - - # normalize the query? - if norm_q: - self.norm = LayerNorm(d_model) - else: - self.norm = lambda x:x - self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v) - - # output linear layer after the multi-head attention? - self.output_layer = nn.Linear(d_model * scale, d_model) - - # apply aoa after attention? 
- self.use_aoa = do_aoa - if self.use_aoa: - self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU()) - # dropout to the input of AoA layer - if dropout_aoa > 0: - self.dropout_aoa = nn.Dropout(p=dropout_aoa) - else: - self.dropout_aoa = lambda x:x - - if self.use_aoa or not use_output_layer: - # AoA doesn't need the output linear layer - del self.output_layer - self.output_layer = lambda x:x - - self.attn = None - self.dropout = nn.Dropout(p=dropout) - - def forward(self, query, value, key, mask=None): - if mask is not None: - if len(mask.size()) == 2: - mask = mask.unsqueeze(-2) - # Same mask applied to all h heads. - mask = mask.unsqueeze(1) - - single_query = 0 - if len(query.size()) == 2: - single_query = 1 - query = query.unsqueeze(1) - - nbatches = query.size(0) - - query = self.norm(query) - - # Do all the linear projections in batch from d_model => h x d_k - if self.project_k_v == 0: - query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - else: - query_, key_, value_ = \ - [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - for l, x in zip(self.linears, (query, key, value))] - - # Apply attention on all the projected vectors in batch. - x, self.attn = attention(query_, key_, value_, mask=mask, - dropout=self.dropout) - - # "Concat" using a view - x = x.transpose(1, 2).contiguous() \ - .view(nbatches, -1, self.h * self.d_k) - - if self.use_aoa: - # Apply AoA - x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1))) - x = self.output_layer(x) - - if single_query: - query = query.squeeze(1) - x = x.squeeze(1) - return x - -class AoA_Refiner_Layer(nn.Module): - def __init__(self, size, self_attn, feed_forward, dropout): - super(AoA_Refiner_Layer, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.use_ff = 0 - if self.feed_forward is not None: - self.use_ff = 1 - self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff) - self.size = size - - def forward(self, x, mask): - x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) - return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x - -class AoA_Refiner_Core(nn.Module): - def __init__(self, opt): - super(AoA_Refiner_Core, self).__init__() - attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3)) - layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1) - self.layers = clones(layer, 6) - self.norm = LayerNorm(layer.size) - - def forward(self, x, mask): - for layer in self.layers: - x = layer(x, mask) - return self.norm(x) - -class AoA_Decoder_Core(nn.Module): - def __init__(self, opt): - super(AoA_Decoder_Core, self).__init__() - self.drop_prob_lm = opt.drop_prob_lm - self.d_model = opt.rnn_size - self.use_multi_head = opt.use_multi_head - self.multi_head_scale = opt.multi_head_scale - self.use_ctx_drop = getattr(opt, 'ctx_drop', 0) - self.out_res = getattr(opt, 'out_res', 0) - self.decoder_type = getattr(opt, 'decoder_type', 'AoA') - self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1 - self.out_drop = nn.Dropout(self.drop_prob_lm) - - if self.decoder_type == 'AoA': - # AoA layer - self.att2ctx = 
nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU()) - elif self.decoder_type == 'LSTM': - # LSTM layer - self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size) - else: - # Base linear layer - self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU()) - - # if opt.use_multi_head == 1: # TODO, not implemented for now - # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale) - if opt.use_multi_head == 2: - self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1) - else: - self.attention = Attention(opt) - - if self.use_ctx_drop: - self.ctx_drop = nn.Dropout(self.drop_prob_lm) - else: - self.ctx_drop = lambda x :x - - def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None): - # state[0][1] is the context vector at the last step - h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0])) - - if self.use_multi_head == 2: - att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks) - else: - att = self.attention(h_att, att_feats, p_att_feats, att_masks) - - ctx_input = torch.cat([att, h_att], 1) - if self.decoder_type == 'LSTM': - output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1])) - state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic))) - else: - output = self.att2ctx(ctx_input) - # save the context vector to state[0][1] - state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))) - - if self.out_res: - # add residual connection - output = output + h_att - - output = self.out_drop(output) - return output, state - -class AoAModel(AttModel): - def __init__(self, opt): - super(AoAModel, self).__init__(opt) - self.num_layers = 2 - # mean pooling - self.use_mean_feats = getattr(opt, 'mean_feats', 1) - if opt.use_multi_head == 2: - del self.ctx2att - self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size) - - if self.use_mean_feats: - del self.fc_embed - if opt.refine: - self.refiner = AoA_Refiner_Core(opt) - else: - self.refiner = lambda x,y : x - self.core = AoA_Decoder_Core(opt) - - self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) - - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - att_feats, att_masks = self.clip_att(att_feats, att_masks) - - # embed att feats - att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) - att_feats = self.refiner(att_feats, att_masks) - - if self.use_mean_feats: - # meaning pooling - if att_masks is None: - mean_feats = torch.mean(att_feats, dim=1) - else: - mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1)) - else: - mean_feats = self.fc_embed(fc_feats) - - # Project the attention feats first to reduce memory and computation. 
- p_att_feats = self.ctx2att(att_feats) - - return mean_feats, att_feats, p_att_feats, att_masks \ No newline at end of file diff --git a/captioning/models/AttEnsemble.py b/captioning/models/AttEnsemble.py deleted file mode 100644 index 19e88e2ace19e4a73fe6fcb1024bd584d77a38fa..0000000000000000000000000000000000000000 --- a/captioning/models/AttEnsemble.py +++ /dev/null @@ -1,90 +0,0 @@ -# This file is the implementation for ensemble evaluation. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import * - -from .CaptionModel import CaptionModel -from .AttModel import pack_wrapper, AttModel - -class AttEnsemble(AttModel): - def __init__(self, models, weights=None): - CaptionModel.__init__(self) - # super(AttEnsemble, self).__init__() - - self.models = nn.ModuleList(models) - self.vocab_size = models[0].vocab_size - self.seq_length = models[0].seq_length - self.bad_endings_ix = models[0].bad_endings_ix - self.ss_prob = 0 - weights = weights or [1.0] * len(self.models) - self.register_buffer('weights', torch.tensor(weights)) - - def init_hidden(self, batch_size): - state = [m.init_hidden(batch_size) for m in self.models] - return self.pack_state(state) - - def pack_state(self, state): - self.state_lengths = [len(_) for _ in state] - return sum([list(_) for _ in state], []) - - def unpack_state(self, state): - out = [] - for l in self.state_lengths: - out.append(state[:l]) - state = state[l:] - return out - - def embed(self, it): - return [m.embed(it) for m in self.models] - - def core(self, *args): - return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) - - def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1): - # 'it' contains a word index - xt = self.embed(it) - - state = self.unpack_state(state) - output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) - logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log() - - return logprobs, self.pack_state(state) - - def _prepare_feature(self, *args): - return tuple(zip(*[m._prepare_feature(*args) for m in self.models])) - - def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - batch_size = fc_feats.size(0) - - fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - - assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' - seq = torch.LongTensor(self.seq_length, batch_size).zero_() - seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) - # lets process every image independently for now, for simplicity - - self.done_beams = [[] for _ in range(batch_size)] - for k in range(batch_size): - state = self.init_hidden(beam_size) - tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] - tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] - tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] - tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)] - - it = fc_feats[0].data.new(beam_size).long().zero_() - logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) - - self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) - seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score - seqLogprobs[:, k] = self.done_beams[k][0]['logps'] - # return the samples and their log likelihoods - return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - # return the samples and their log likelihoods diff --git a/captioning/models/AttModel.py b/captioning/models/AttModel.py deleted file mode 100644 index 3dc4e5b7a78c4affbfba4044ca8c96c30b26e36a..0000000000000000000000000000000000000000 --- a/captioning/models/AttModel.py +++ /dev/null @@ -1,969 +0,0 @@ -# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model - -# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning -# https://arxiv.org/abs/1612.01887 -# AdaAttMO is a modified version with maxout lstm - -# Att2in is from Self-critical Sequence Training for Image Captioning -# https://arxiv.org/abs/1612.00563 -# In this file we only have Att2in2, which is a slightly different version of att2in, -# in which the img feature embedding and word embedding is the same as what in adaatt. - -# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA -# https://arxiv.org/abs/1707.07998 -# However, it may not be identical to the author's architecture. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from . 
import utils -from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence - -from .CaptionModel import CaptionModel - -bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] -bad_endings += ['the'] - -def sort_pack_padded_sequence(input, lengths): - sorted_lengths, indices = torch.sort(lengths, descending=True) - # tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) - tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) - inv_ix = indices.clone() - inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) - return tmp, inv_ix - -def pad_unsort_packed_sequence(input, inv_ix): - tmp, _ = pad_packed_sequence(input, batch_first=True) - tmp = tmp[inv_ix] - return tmp - -def pack_wrapper(module, att_feats, att_masks): - if att_masks is not None: - packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) - return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) - else: - return module(att_feats) - -class AttModel(CaptionModel): - def __init__(self, opt): - super(AttModel, self).__init__() - self.vocab_size = opt.vocab_size - self.input_encoding_size = opt.input_encoding_size - #self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length - self.fc_feat_size = opt.fc_feat_size - self.att_feat_size = opt.att_feat_size - self.att_hid_size = opt.att_hid_size - - self.bos_idx = getattr(opt, 'bos_idx', 0) - self.eos_idx = getattr(opt, 'eos_idx', 0) - self.pad_idx = getattr(opt, 'pad_idx', 0) - - self.use_bn = getattr(opt, 'use_bn', 0) - - self.ss_prob = 0.0 # Schedule sampling probability - - self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm)) - self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm)) - self.att_embed = nn.Sequential(*( - ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ - (nn.Linear(self.att_feat_size, self.rnn_size), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm))+ - ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) - - self.logit_layers = getattr(opt, 'logit_layers', 1) - if self.logit_layers == 1: - self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) - else: - self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] - self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) - self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) - - # For remove bad endding - self.vocab = opt.vocab - self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings] - - def init_hidden(self, bsz): - weight = self.logit.weight \ - if hasattr(self.logit, "weight") \ - else self.logit[0].weight - return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), - weight.new_zeros(self.num_layers, bsz, self.rnn_size)) - - def clip_att(self, att_feats, att_masks): - # Clip the length of att_masks and att_feats to the maximum length - if att_masks is not None: - max_len = att_masks.data.long().sum(1).max() - att_feats = att_feats[:, :max_len].contiguous() - att_masks = att_masks[:, :max_len].contiguous() - return att_feats, 
att_masks - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - att_feats, att_masks = self.clip_att(att_feats, att_masks) - - # embed fc and att feats - fc_feats = self.fc_embed(fc_feats) - att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) - - # Project the attention feats first to reduce memory and computation comsumptions. - p_att_feats = self.ctx2att(att_feats) - - return fc_feats, att_feats, p_att_feats, att_masks - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - batch_size = fc_feats.size(0) - if seq.ndim == 3: # B * seq_per_img * seq_len - seq = seq.reshape(-1, seq.shape[2]) - seq_per_img = seq.shape[0] // batch_size - state = self.init_hidden(batch_size*seq_per_img) - - outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1) - - # Prepare the features - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - # pp_att_feats is used for attention, we cache it in advance to reduce computation cost - - if seq_per_img > 1: - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img, - [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] - ) - - for i in range(seq.size(1)): - if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample - sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1) - sample_mask = sample_prob < self.ss_prob - if sample_mask.sum() == 0: - it = seq[:, i].clone() - else: - sample_ind = sample_mask.nonzero().view(-1) - it = seq[:, i].data.clone() - prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) - it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - else: - it = seq[:, i].clone() - # break if all the sequences end - if i >= 1 and seq[:, i].sum() == 0: - break - - output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) - outputs[:, i] = output - - return outputs - - def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): - # 'it' contains a word index - xt = self.embed(it) - - output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) - if output_logsoftmax: - logprobs = F.log_softmax(self.logit(output), dim=1) - else: - logprobs = self.logit(output) - - return logprobs, state - - def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - group_size = opt.get('group_size', 1) - sample_n = opt.get('sample_n', 10) - # when sample_n == beam_size then each beam is a sample. - assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' - batch_size = fc_feats.size(0) - - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - - assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' - seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) - # lets process every image independently for now, for simplicity - - self.done_beams = [[] for _ in range(batch_size)] - for k in range(batch_size): - state = self.init_hidden(beam_size) - tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size, - [p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None] - ) - - for t in range(1): - if t == 0: # input - it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long) - - logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) - - self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) - if sample_n == beam_size: - for _n in range(sample_n): - seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq'] - seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps'] - else: - seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score - seqLogprobs[k, :] = self.done_beams[k][0]['logps'] - # return the samples and their log likelihoods - return seq, seqLogprobs - - - def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - group_size = opt.get('group_size', 1) - sample_n = opt.get('sample_n', 10) - # when sample_n == beam_size then each beam is a sample. - assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' - batch_size = fc_feats.size(0) - - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - - assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' - seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) - # lets process every image independently for now, for simplicity - - self.done_beams = [[] for _ in range(batch_size)] - - state = self.init_hidden(batch_size) - - # first step, feed bos - it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) - logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) - - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size, - [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] - ) - self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) - for k in range(batch_size): - if sample_n == beam_size: - for _n in range(sample_n): - seq_len = self.done_beams[k][_n]['seq'].shape[0] - seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq'] - seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps'] - else: - seq_len = self.done_beams[k][0]['seq'].shape[0] - seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score - seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] - # return the samples and their log likelihoods - return seq, seqLogprobs - - def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): - - sample_method = opt.get('sample_method', 'greedy') - beam_size = opt.get('beam_size', 1) - temperature = opt.get('temperature', 1.0) - sample_n = int(opt.get('sample_n', 1)) - group_size = opt.get('group_size', 1) - output_logsoftmax = opt.get('output_logsoftmax', 1) - decoding_constraint = opt.get('decoding_constraint', 0) - block_trigrams = opt.get('block_trigrams', 0) - remove_bad_endings = opt.get('remove_bad_endings', 0) - if beam_size > 1 and sample_method in ['greedy', 'beam_search']: - return self._sample_beam(fc_feats, att_feats, att_masks, opt) - if group_size > 1: - return self._diverse_sample(fc_feats, att_feats, att_masks, opt) - - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size*sample_n) - - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - - if sample_n > 1: - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n, - [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] - ) - - trigrams = [] # will be a list of batch_size dictionaries - - seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) - for t in range(self.seq_length + 1): - if t == 0: # input - it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long) - - logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax) - - if decoding_constraint and t > 0: - tmp = logprobs.new_zeros(logprobs.size()) - tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) - logprobs = logprobs + tmp - - if remove_bad_endings and t > 0: - tmp = logprobs.new_zeros(logprobs.size()) - prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) - # Make it impossible to generate bad_endings - tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') - logprobs = logprobs + tmp - - # Mess with trigrams - # Copy 
from https://github.com/lukemelas/image-paragraph-captioning - if block_trigrams and t >= 3: - # Store trigram generated at last step - prev_two_batch = seq[:,t-3:t-1] - for i in range(batch_size): # = seq.size(0) - prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) - current = seq[i][t-1] - if t == 3: # initialize - trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} - elif t > 3: - if prev_two in trigrams[i]: # add to list - trigrams[i][prev_two].append(current) - else: # create list - trigrams[i][prev_two] = [current] - # Block used trigrams at next step - prev_two_batch = seq[:,t-2:t] - mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size - for i in range(batch_size): - prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) - if prev_two in trigrams[i]: - for j in trigrams[i][prev_two]: - mask[i,j] += 1 - # Apply mask to log probs - #logprobs = logprobs - (mask * 1e9) - alpha = 2.0 # = 4 - logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) - - # sample the next word - if t == self.seq_length: # skip if we achieve maximum length - break - it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) - - # stop when all finished - if t == 0: - unfinished = it != self.eos_idx - else: - it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 - logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs) - unfinished = unfinished & (it != self.eos_idx) - seq[:,t] = it - seqLogprobs[:,t] = logprobs - # quit loop if all sequences have finished - if unfinished.sum() == 0: - break - - return seq, seqLogprobs - - def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): - - sample_method = opt.get('sample_method', 'greedy') - beam_size = opt.get('beam_size', 1) - temperature = opt.get('temperature', 1.0) - group_size = opt.get('group_size', 1) - diversity_lambda = opt.get('diversity_lambda', 0.5) - decoding_constraint = opt.get('decoding_constraint', 0) - block_trigrams = opt.get('block_trigrams', 0) - remove_bad_endings = opt.get('remove_bad_endings', 0) - - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size) - - p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) - - trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries - - seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)] - seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)] - state_table = [self.init_hidden(batch_size) for _ in range(group_size)] - - for tt in range(self.seq_length + group_size): - for divm in range(group_size): - t = tt - divm - seq = seq_table[divm] - seqLogprobs = seqLogprobs_table[divm] - trigrams = trigrams_table[divm] - if t >= 0 and t <= self.seq_length-1: - if t == 0: # input - it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) - else: - it = seq[:, t-1] # changed - - logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed - logprobs = F.log_softmax(logprobs / temperature, dim=-1) - - # Add diversity - if divm > 0: - unaug_logprobs = logprobs.clone() - for prev_choice in range(divm): - prev_decisions = seq_table[prev_choice][:, t] - logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - 
diversity_lambda - - if decoding_constraint and t > 0: - tmp = logprobs.new_zeros(logprobs.size()) - tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) - logprobs = logprobs + tmp - - if remove_bad_endings and t > 0: - tmp = logprobs.new_zeros(logprobs.size()) - prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) - # Impossible to generate remove_bad_endings - tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') - logprobs = logprobs + tmp - - # Mess with trigrams - if block_trigrams and t >= 3: - # Store trigram generated at last step - prev_two_batch = seq[:,t-3:t-1] - for i in range(batch_size): # = seq.size(0) - prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) - current = seq[i][t-1] - if t == 3: # initialize - trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} - elif t > 3: - if prev_two in trigrams[i]: # add to list - trigrams[i][prev_two].append(current) - else: # create list - trigrams[i][prev_two] = [current] - # Block used trigrams at next step - prev_two_batch = seq[:,t-2:t] - mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size - for i in range(batch_size): - prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) - if prev_two in trigrams[i]: - for j in trigrams[i][prev_two]: - mask[i,j] += 1 - # Apply mask to log probs - #logprobs = logprobs - (mask * 1e9) - alpha = 2.0 # = 4 - logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) - - it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) - - # stop when all finished - if t == 0: - unfinished = it != self.eos_idx - else: - unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx) - it[~unfinished] = self.pad_idx - unfinished = unfinished & (it != self.eos_idx) # changed - seq[:,t] = it - seqLogprobs[:,t] = sampleLogprobs.view(-1) - - return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1) - -class AdaAtt_lstm(nn.Module): - def __init__(self, opt, use_maxout=True): - super(AdaAtt_lstm, self).__init__() - self.input_encoding_size = opt.input_encoding_size - #self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.fc_feat_size = opt.fc_feat_size - self.att_feat_size = opt.att_feat_size - self.att_hid_size = opt.att_hid_size - - self.use_maxout = use_maxout - - # Build a LSTM - self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size) - self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) - - self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)]) - self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)]) - - # Layers for getting the fake region - if self.num_layers == 1: - self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size) - self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size) - else: - self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size) - self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size) - - - def forward(self, xt, img_fc, state): - - hs = [] - cs = [] - for L in range(self.num_layers): - # c,h from previous timesteps - prev_h = state[0][L] - prev_c = state[1][L] - # the input to this layer - if L == 0: - x = xt - i2h = self.w2h(x) + 
self.v2h(img_fc) - else: - x = hs[-1] - x = F.dropout(x, self.drop_prob_lm, self.training) - i2h = self.i2h[L-1](x) - - all_input_sums = i2h+self.h2h[L](prev_h) - - sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) - sigmoid_chunk = torch.sigmoid(sigmoid_chunk) - # decode the gates - in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) - forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) - out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) - # decode the write inputs - if not self.use_maxout: - in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size)) - else: - in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) - in_transform = torch.max(\ - in_transform.narrow(1, 0, self.rnn_size), - in_transform.narrow(1, self.rnn_size, self.rnn_size)) - # perform the LSTM update - next_c = forget_gate * prev_c + in_gate * in_transform - # gated cells form the output - tanh_nex_c = torch.tanh(next_c) - next_h = out_gate * tanh_nex_c - if L == self.num_layers-1: - if L == 0: - i2h = self.r_w2h(x) + self.r_v2h(img_fc) - else: - i2h = self.r_i2h(x) - n5 = i2h+self.r_h2h(prev_h) - fake_region = torch.sigmoid(n5) * tanh_nex_c - - cs.append(next_c) - hs.append(next_h) - - # set up the decoder - top_h = hs[-1] - top_h = F.dropout(top_h, self.drop_prob_lm, self.training) - fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training) - - state = (torch.cat([_.unsqueeze(0) for _ in hs], 0), - torch.cat([_.unsqueeze(0) for _ in cs], 0)) - return top_h, fake_region, state - -class AdaAtt_attention(nn.Module): - def __init__(self, opt): - super(AdaAtt_attention, self).__init__() - self.input_encoding_size = opt.input_encoding_size - #self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - self.drop_prob_lm = opt.drop_prob_lm - self.att_hid_size = opt.att_hid_size - - # fake region embed - self.fr_linear = nn.Sequential( - nn.Linear(self.rnn_size, self.input_encoding_size), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm)) - self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) - - # h out embed - self.ho_linear = nn.Sequential( - nn.Linear(self.rnn_size, self.input_encoding_size), - nn.Tanh(), - nn.Dropout(self.drop_prob_lm)) - self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) - - self.alpha_net = nn.Linear(self.att_hid_size, 1) - self.att2h = nn.Linear(self.rnn_size, self.rnn_size) - - def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None): - - # View into three dimensions - att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size - conv_feat = conv_feat.view(-1, att_size, self.rnn_size) - conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size) - - # view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num - fake_region = self.fr_linear(fake_region) - fake_region_embed = self.fr_embed(fake_region) - - h_out_linear = self.ho_linear(h_out) - h_out_embed = self.ho_embed(h_out_linear) - - txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1)) - - img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1) - img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1) - - hA = torch.tanh(img_all_embed + txt_replicate) - hA = F.dropout(hA,self.drop_prob_lm, self.training) - - hAflat = self.alpha_net(hA.view(-1, self.att_hid_size)) - PI = F.softmax(hAflat.view(-1, 
att_size + 1), dim=1) - - if att_masks is not None: - att_masks = att_masks.view(-1, att_size) - PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step. - PI = PI / PI.sum(1, keepdim=True) - - visAtt = torch.bmm(PI.unsqueeze(1), img_all) - visAttdim = visAtt.squeeze(1) - - atten_out = visAttdim + h_out_linear - - h = torch.tanh(self.att2h(atten_out)) - h = F.dropout(h, self.drop_prob_lm, self.training) - return h - -class AdaAttCore(nn.Module): - def __init__(self, opt, use_maxout=False): - super(AdaAttCore, self).__init__() - self.lstm = AdaAtt_lstm(opt, use_maxout) - self.attention = AdaAtt_attention(opt) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - h_out, p_out, state = self.lstm(xt, fc_feats, state) - atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks) - return atten_out, state - -class UpDownCore(nn.Module): - def __init__(self, opt, use_maxout=False): - super(UpDownCore, self).__init__() - self.drop_prob_lm = opt.drop_prob_lm - - self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 - self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v - self.attention = Attention(opt) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - prev_h = state[0][-1] - att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) - - h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) - - att = self.attention(h_att, att_feats, p_att_feats, att_masks) - - lang_lstm_input = torch.cat([att, h_att], 1) - # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? - - h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) - - output = F.dropout(h_lang, self.drop_prob_lm, self.training) - state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) - - return output, state - - -############################################################################ -# Notice: -# StackAtt and DenseAtt are models that I randomly designed. -# They are not related to any paper. 
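# Editor's aside: a small sketch (toy sizes, invented tensors) of the fused-gate
# decoding pattern used by AdaAtt_lstm above and by Att2in2Core / Att2all2Core
# below. The summed linear projections yield 5 * rnn_size pre-activations when
# maxout is enabled: the first three rnn_size chunks become the input / forget /
# output gates, and the remaining 2 * rnn_size chunk is reduced by an
# element-wise max (maxout) to form the cell input.
import torch

rnn_size = 4
all_input_sums = torch.randn(2, 5 * rnn_size)          # batch x (5 * rnn_size)
sigmoid_chunk = torch.sigmoid(all_input_sums.narrow(1, 0, 3 * rnn_size))
in_gate      = sigmoid_chunk.narrow(1, 0, rnn_size)
forget_gate  = sigmoid_chunk.narrow(1, rnn_size, rnn_size)
out_gate     = sigmoid_chunk.narrow(1, 2 * rnn_size, rnn_size)
maxout_chunk = all_input_sums.narrow(1, 3 * rnn_size, 2 * rnn_size)
in_transform = torch.max(maxout_chunk.narrow(1, 0, rnn_size),
                         maxout_chunk.narrow(1, rnn_size, rnn_size))
# next_c = forget_gate * prev_c + in_gate * in_transform, exactly as in the
# LSTM updates above.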
-############################################################################ - -from .FCModel import LSTMCore -class StackAttCore(nn.Module): - def __init__(self, opt, use_maxout=False): - super(StackAttCore, self).__init__() - self.drop_prob_lm = opt.drop_prob_lm - - # self.att0 = Attention(opt) - self.att1 = Attention(opt) - self.att2 = Attention(opt) - - opt_input_encoding_size = opt.input_encoding_size - opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size - self.lstm0 = LSTMCore(opt) # att_feat + word_embedding - opt.input_encoding_size = opt.rnn_size * 2 - self.lstm1 = LSTMCore(opt) - self.lstm2 = LSTMCore(opt) - opt.input_encoding_size = opt_input_encoding_size - - # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) - self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) - h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) - att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) - h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) - att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) - h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]]) - - return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] - -class DenseAttCore(nn.Module): - def __init__(self, opt, use_maxout=False): - super(DenseAttCore, self).__init__() - self.drop_prob_lm = opt.drop_prob_lm - - # self.att0 = Attention(opt) - self.att1 = Attention(opt) - self.att2 = Attention(opt) - - opt_input_encoding_size = opt.input_encoding_size - opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size - self.lstm0 = LSTMCore(opt) # att_feat + word_embedding - opt.input_encoding_size = opt.rnn_size * 2 - self.lstm1 = LSTMCore(opt) - self.lstm2 = LSTMCore(opt) - opt.input_encoding_size = opt_input_encoding_size - - # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) - self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) - - # fuse h_0 and h_1 - self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size), - nn.ReLU(), - nn.Dropout(opt.drop_prob_lm)) - # fuse h_0, h_1 and h_2 - self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size), - nn.ReLU(), - nn.Dropout(opt.drop_prob_lm)) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) - h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) - att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) - h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) - att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) - h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]]) - - return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] - -class Attention(nn.Module): - def __init__(self, opt): - super(Attention, self).__init__() - self.rnn_size = opt.rnn_size - self.att_hid_size = opt.att_hid_size - - self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) - self.alpha_net = nn.Linear(self.att_hid_size, 1) - - def forward(self, h, att_feats, p_att_feats, att_masks=None): - # The p_att_feats here is already projected - att_size = 
att_feats.numel() // att_feats.size(0) // att_feats.size(-1) - att = p_att_feats.view(-1, att_size, self.att_hid_size) - - att_h = self.h2att(h) # batch * att_hid_size - att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size - dot = att + att_h # batch * att_size * att_hid_size - dot = torch.tanh(dot) # batch * att_size * att_hid_size - dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size - dot = self.alpha_net(dot) # (batch * att_size) * 1 - dot = dot.view(-1, att_size) # batch * att_size - - weight = F.softmax(dot, dim=1) # batch * att_size - if att_masks is not None: - weight = weight * att_masks.view(-1, att_size).to(weight) - weight = weight / weight.sum(1, keepdim=True) # normalize to 1 - att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size - att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size - - return att_res - -class Att2in2Core(nn.Module): - def __init__(self, opt): - super(Att2in2Core, self).__init__() - self.input_encoding_size = opt.input_encoding_size - #self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - #self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.fc_feat_size = opt.fc_feat_size - self.att_feat_size = opt.att_feat_size - self.att_hid_size = opt.att_hid_size - - # Build a LSTM - self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size) - self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) - self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) - self.dropout = nn.Dropout(self.drop_prob_lm) - - self.attention = Attention(opt) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) - - all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) - sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) - sigmoid_chunk = torch.sigmoid(sigmoid_chunk) - in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) - forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) - out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) - - in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ - self.a2c(att_res) - in_transform = torch.max(\ - in_transform.narrow(1, 0, self.rnn_size), - in_transform.narrow(1, self.rnn_size, self.rnn_size)) - next_c = forget_gate * state[1][-1] + in_gate * in_transform - next_h = out_gate * torch.tanh(next_c) - - output = self.dropout(next_h) - state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) - return output, state - -class Att2inCore(Att2in2Core): - def __init__(self, opt): - super(Att2inCore, self).__init__(opt) - del self.a2c - self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size) - -""" -Note this is my attempt to replicate att2all model in self-critical paper. -However, this is not a correct replication actually. Will fix it. 
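# Editor's aside: a compact sketch (random toy tensors) of the masked attention
# weighting implemented by Attention.forward above. Scores are softmaxed over
# all regions first; padded regions are then zeroed with att_masks and the
# weights are renormalized before pooling the attention features.
import torch
import torch.nn.functional as F

dot = torch.randn(2, 5)                                # batch x att_size scores
att_masks = torch.tensor([[1., 1., 1., 0., 0.],
                          [1., 1., 1., 1., 1.]])
weight = F.softmax(dot, dim=1) * att_masks
weight = weight / weight.sum(1, keepdim=True)          # re-normalize to sum to 1
att_feats = torch.randn(2, 5, 8)                       # batch x att_size x feat_dim
att_res = torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)   # batch x feat_dim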
-""" -class Att2all2Core(nn.Module): - def __init__(self, opt): - super(Att2all2Core, self).__init__() - self.input_encoding_size = opt.input_encoding_size - #self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - #self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.fc_feat_size = opt.fc_feat_size - self.att_feat_size = opt.att_feat_size - self.att_hid_size = opt.att_hid_size - - # Build a LSTM - self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) - self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) - self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) - self.dropout = nn.Dropout(self.drop_prob_lm) - - self.attention = Attention(opt) - - def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): - att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) - - all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res) - sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) - sigmoid_chunk = torch.sigmoid(sigmoid_chunk) - in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) - forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) - out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) - - in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) - in_transform = torch.max(\ - in_transform.narrow(1, 0, self.rnn_size), - in_transform.narrow(1, self.rnn_size, self.rnn_size)) - next_c = forget_gate * state[1][-1] + in_gate * in_transform - next_h = out_gate * torch.tanh(next_c) - - output = self.dropout(next_h) - state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) - return output, state - -class AdaAttModel(AttModel): - def __init__(self, opt): - super(AdaAttModel, self).__init__(opt) - self.core = AdaAttCore(opt) - -# AdaAtt with maxout lstm -class AdaAttMOModel(AttModel): - def __init__(self, opt): - super(AdaAttMOModel, self).__init__(opt) - self.core = AdaAttCore(opt, True) - -class Att2in2Model(AttModel): - def __init__(self, opt): - super(Att2in2Model, self).__init__(opt) - self.core = Att2in2Core(opt) - delattr(self, 'fc_embed') - self.fc_embed = lambda x : x - -class Att2all2Model(AttModel): - def __init__(self, opt): - super(Att2all2Model, self).__init__(opt) - self.core = Att2all2Core(opt) - delattr(self, 'fc_embed') - self.fc_embed = lambda x : x - -class UpDownModel(AttModel): - def __init__(self, opt): - super(UpDownModel, self).__init__(opt) - self.num_layers = 2 - self.core = UpDownCore(opt) - -class StackAttModel(AttModel): - def __init__(self, opt): - super(StackAttModel, self).__init__(opt) - self.num_layers = 3 - self.core = StackAttCore(opt) - -class DenseAttModel(AttModel): - def __init__(self, opt): - super(DenseAttModel, self).__init__(opt) - self.num_layers = 3 - self.core = DenseAttCore(opt) - -class Att2inModel(AttModel): - def __init__(self, opt): - super(Att2inModel, self).__init__(opt) - del self.embed, self.fc_embed, self.att_embed - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self.fc_embed = self.att_embed = lambda x: x - del self.ctx2att - self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) - self.core = Att2inCore(opt) - self.init_weights() - - def init_weights(self): - initrange = 0.1 - self.embed.weight.data.uniform_(-initrange, initrange) - self.logit.bias.data.fill_(0) - self.logit.weight.data.uniform_(-initrange, initrange) - - -class NewFCModel(AttModel): - def __init__(self, opt): - super(NewFCModel, self).__init__(opt) - self.fc_embed = 
nn.Linear(self.fc_feat_size, self.input_encoding_size) - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self._core = LSTMCore(opt) - delattr(self, 'att_embed') - self.att_embed = lambda x : x - delattr(self, 'ctx2att') - self.ctx2att = lambda x: x - - def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks): - # Step 0, feed the input image - # if (self.training and state[0].is_leaf) or \ - # (not self.training and state[0].sum() == 0): - # _, state = self._core(fc_feats, state) - # three cases - # normal mle training - # Sample - # beam search (diverse beam search) - # fixed captioning module. - is_first_step = (state[0]==0).all(2).all(0) # size: B - if is_first_step.all(): - _, state = self._core(fc_feats, state) - elif is_first_step.any(): - # This is mostly for diverse beam search I think - new_state = [torch.zeros_like(_) for _ in state] - new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step] - new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step] - _, state = self._core(fc_feats, state) - new_state[0][:, is_first_step] = state[0][:, is_first_step] - new_state[1][:, is_first_step] = state[1][:, is_first_step] - state = new_state - # if (state[0]==0).all(): - # # Let's forget about diverse beam search first - # _, state = self._core(fc_feats, state) - return self._core(xt, state) - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - fc_feats = self.fc_embed(fc_feats) - - return fc_feats, att_feats, att_feats, att_masks - - -class LMModel(AttModel): - def __init__(self, opt): - super(LMModel, self).__init__(opt) - delattr(self, 'fc_embed') - self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size) - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self._core = LSTMCore(opt) - delattr(self, 'att_embed') - self.att_embed = lambda x : x - delattr(self, 'ctx2att') - self.ctx2att = lambda x: x - - def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks): - if (state[0]==0).all(): - # Let's forget about diverse beam search first - _, state = self._core(fc_feats, state) - return self._core(xt, state) - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - fc_feats = self.fc_embed(fc_feats) - - return fc_feats, None, None, None \ No newline at end of file diff --git a/captioning/models/BertCapModel.py b/captioning/models/BertCapModel.py deleted file mode 100644 index 3a7ccec2c40b2a171393059ec1a3af511163c246..0000000000000000000000000000000000000000 --- a/captioning/models/BertCapModel.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -BertCapModel is using huggingface transformer bert model as seq2seq model. - -The result is not as goog as original transformer. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import copy -import math -import numpy as np - -from .CaptionModel import CaptionModel -from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel -try: - from transformers import BertModel, BertConfig -except: - print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers') -from .TransformerModel import subsequent_mask, TransformerModel, Generator - -class EncoderDecoder(nn.Module): - """ - A standard Encoder-Decoder architecture. Base for this and many - other models. 
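# Editor's aside: BertCapModel.core below decodes one token at a time and masks
# future positions with subsequent_mask from TransformerModel.py (not shown in
# this diff). A minimal causal-mask sketch, presumably equivalent to that
# helper: position t may only attend to positions <= t.
import torch

size = 4
causal = torch.tril(torch.ones(1, size, size, dtype=torch.bool))
# causal[0, 2] == tensor([True, True, True, False]): token 2 attends to tokens 0-2 only.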
- """ - def __init__(self, encoder, decoder, generator): - super(EncoderDecoder, self).__init__() - self.encoder = encoder - self.decoder = decoder - self.generator = generator - - def forward(self, src, tgt, src_mask, tgt_mask): - "Take in and process masked src and target sequences." - return self.decode(self.encode(src, src_mask), src_mask, - tgt, tgt_mask) - - def encode(self, src, src_mask): - return self.encoder(inputs_embeds=src, - attention_mask=src_mask)[0] - - def decode(self, memory, src_mask, tgt, tgt_mask): - return self.decoder(input_ids=tgt, - attention_mask=tgt_mask, - encoder_hidden_states=memory, - encoder_attention_mask=src_mask)[0] - - -class BertCapModel(TransformerModel): - - def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, - d_model=512, d_ff=2048, h=8, dropout=0.1): - "Helper: Construct a model from hyperparameters." - enc_config = BertConfig(vocab_size=1, - hidden_size=d_model, - num_hidden_layers=N_enc, - num_attention_heads=h, - intermediate_size=d_ff, - hidden_dropout_prob=dropout, - attention_probs_dropout_prob=dropout, - max_position_embeddings=1, - type_vocab_size=1) - dec_config = BertConfig(vocab_size=tgt_vocab, - hidden_size=d_model, - num_hidden_layers=N_dec, - num_attention_heads=h, - intermediate_size=d_ff, - hidden_dropout_prob=dropout, - attention_probs_dropout_prob=dropout, - max_position_embeddings=17, - type_vocab_size=1, - is_decoder=True) - encoder = BertModel(enc_config) - def return_embeds(*args, **kwargs): - return kwargs['inputs_embeds'] - del encoder.embeddings; encoder.embeddings = return_embeds - decoder = BertModel(dec_config) - model = EncoderDecoder( - encoder, - decoder, - Generator(d_model, tgt_vocab)) - return model - - def __init__(self, opt): - super(BertCapModel, self).__init__(opt) - - def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): - """ - state = [ys.unsqueeze(0)] - """ - if len(state) == 0: - ys = it.unsqueeze(1) - else: - ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) - out = self.model.decode(memory, mask, - ys, - subsequent_mask(ys.size(1)) - .to(memory.device)) - return out[:, -1], [ys.unsqueeze(0)] diff --git a/captioning/models/CaptionModel.py b/captioning/models/CaptionModel.py deleted file mode 100644 index 221ecd1e173d2e20e0103d4cde328d82bfd6b66c..0000000000000000000000000000000000000000 --- a/captioning/models/CaptionModel.py +++ /dev/null @@ -1,407 +0,0 @@ -# This file contains ShowAttendTell and AllImg model - -# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention -# https://arxiv.org/abs/1502.03044 - -# AllImg is a model where -# img feature is concatenated with word embedding at every time step as the input of lstm -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import * -from ..utils import misc as utils -from . 
import utils as model_utils - - -class CaptionModel(nn.Module): - def __init__(self): - super(CaptionModel, self).__init__() - - # implements beam search - # calls beam_step and returns the final set of beams - # augments log-probabilities with diversity terms when number of groups > 1 - - def forward(self, *args, **kwargs): - mode = kwargs.get('mode', 'forward') - if 'mode' in kwargs: - del kwargs['mode'] - return getattr(self, '_'+mode)(*args, **kwargs) - - def beam_search(self, init_state, init_logprobs, *args, **kwargs): - - # function computes the similarity score to be augmented - def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash): - local_time = t - divm - unaug_logprobs = logprobs.clone() - batch_size = beam_seq_table[0].shape[0] - - if divm > 0: - change = logprobs.new_zeros(batch_size, logprobs.shape[-1]) - for prev_choice in range(divm): - prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb - for prev_labels in range(bdash): - change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1)) - - if local_time == 0: - logprobs = logprobs - change * diversity_lambda - else: - logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda - - return logprobs, unaug_logprobs - - - # does one step of classical beam search - - def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): - #INPUTS: - #logprobs: probabilities augmented after diversity N*bxV - #beam_size: obvious - #t : time instant - #beam_seq : tensor contanining the beams - #beam_seq_logprobs: tensor contanining the beam logprobs - #beam_logprobs_sum: tensor contanining joint logprobs - #OUPUTS: - #beam_seq : tensor containing the word indices of the decoded captions Nxbxl - #beam_seq_logprobs : log-probability of each decision made, NxbxlxV - #beam_logprobs_sum : joint log-probability of each beam Nxb - - batch_size = beam_logprobs_sum.shape[0] - vocab_size = logprobs.shape[-1] - logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV - if t == 0: - assert logprobs.shape[1] == 1 - beam_logprobs_sum = beam_logprobs_sum[:, :1] - candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV - ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) - ys, ix = ys[:,:beam_size], ix[:,:beam_size] - beam_ix = ix // vocab_size # Nxb which beam - selected_ix = ix % vocab_size # Nxb # which world - state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams - - - if t > 0: - # gather according to beam_ix - assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() - beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) - - beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs)) - - beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl - beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ - logprobs.reshape(batch_size, -1).gather(1, ix) - assert (beam_logprobs_sum == ys).all() - _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) - beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV - assert (_tmp_beam_logprobs == 
beam_logprobs).all() - beam_seq_logprobs = torch.cat([ - beam_seq_logprobs, - beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) - - new_state = [None for _ in state] - for _ix in range(len(new_state)): - # copy over state in previous beam q to new beam at vix - new_state[_ix] = state[_ix][:, state_ix] - state = new_state - return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state - - # Start diverse_beam_search - opt = kwargs['opt'] - temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs - beam_size = opt.get('beam_size', 10) - group_size = opt.get('group_size', 1) - diversity_lambda = opt.get('diversity_lambda', 0.5) - decoding_constraint = opt.get('decoding_constraint', 0) - remove_bad_endings = opt.get('remove_bad_endings', 0) - suppress_UNK = opt.get('suppress_UNK', 0) - length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) - bdash = beam_size // group_size # beam per group - - batch_size = init_logprobs.shape[0] - device = init_logprobs.device - # INITIALIZATIONS - beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)] - beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)] - beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)] - - # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) - done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)] - # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] - # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state])) - state_table = [[_.clone() for _ in init_state] for _ in range(group_size)] - # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0)) - logprobs_table = [init_logprobs.clone() for _ in range(group_size)] - # END INIT - - # Chunk elements in the args - args = list(args) - args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x... - if self.__class__.__name__ == 'AttEnsemble': - args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name - else: - args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] - - for t in range(self.seq_length + group_size - 1): - for divm in range(group_size): - if t >= divm and t <= self.seq_length + divm - 1: - # add diversity - logprobs = logprobs_table[divm] - # suppress previous word - if decoding_constraint and t-divm > 0: - logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf')) - if remove_bad_endings and t-divm > 0: - logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf') - # suppress UNK tokens in the decoding - if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK': - logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000 - # diversity is added here - # the function directly modifies the logprobs values and hence, we need to return - # the unaugmented ones for sorting the candidates in the end. 
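# Editor's aside: a toy illustration (invented numbers) of the diversity term
# computed by add_diversity above for diverse beam search. Each word already
# chosen by an earlier beam group at this time step has diversity_lambda
# subtracted from its log-probability, once per previous beam that chose it.
import torch

vocab_size, diversity_lambda = 6, 0.5
logprobs = torch.zeros(1, vocab_size)                  # one image, flat distribution
prev_decisions = torch.tensor([[2, 4]])                # words picked by an earlier group (N x bdash)
change = torch.zeros(1, vocab_size)
change.scatter_add_(1, prev_decisions, torch.ones(1, 2))
penalized = logprobs - change * diversity_lambda       # words 2 and 4 are now discouraged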
# for historical - # reasons :-) - logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash) - - # infer new beams - beam_seq_table[divm],\ - beam_seq_logprobs_table[divm],\ - beam_logprobs_sum_table[divm],\ - state_table[divm] = beam_step(logprobs, - unaug_logprobs, - bdash, - t-divm, - beam_seq_table[divm], - beam_seq_logprobs_table[divm], - beam_logprobs_sum_table[divm], - state_table[divm]) - - # if time's up... or if end token is reached then copy beams - for b in range(batch_size): - is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx - assert beam_seq_table[divm].shape[-1] == t-divm+1 - if t == self.seq_length + divm - 1: - is_end.fill_(1) - for vix in range(bdash): - if is_end[vix]: - final_beam = { - 'seq': beam_seq_table[divm][b, vix].clone(), - 'logps': beam_seq_logprobs_table[divm][b, vix].clone(), - 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(), - 'p': beam_logprobs_sum_table[divm][b, vix].item() - } - final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) - done_beams_table[b][divm].append(final_beam) - beam_logprobs_sum_table[divm][b, is_end] -= 1000 - - # move the current group one step forward in time - - it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device) - logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) - logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) - - # all beams are sorted by their log-probabilities - done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)] - done_beams = [sum(_, []) for _ in done_beams_table] - return done_beams - - def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): - - # function computes the similarity score to be augmented - def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): - local_time = t - divm - unaug_logprobsf = logprobsf.clone() - for prev_choice in range(divm): - prev_decisions = beam_seq_table[prev_choice][local_time] - for sub_beam in range(bdash): - for prev_labels in range(bdash): - logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda - return unaug_logprobsf - - # does one step of classical beam search - - def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): - #INPUTS: - #logprobsf: probabilities augmented after diversity - #beam_size: obvious - #t : time instant - #beam_seq : tensor contanining the beams - #beam_seq_logprobs: tensor contanining the beam logprobs - #beam_logprobs_sum: tensor contanining joint logprobs - #OUPUTS: - #beam_seq : tensor containing the word indices of the decoded captions - #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq - #beam_logprobs_sum : joint log-probability of each beam - - ys,ix = torch.sort(logprobsf,1,True) - candidates = [] - cols = min(beam_size, ys.size(1)) - rows = beam_size - if t == 0: - rows = 1 - for c in range(cols): # for each column (word, essentially) - for q in range(rows): # for each beam expansion - #compute logprob of expanding beam q with word in (sorted) position c - local_logprob = ys[q,c].item() - candidate_logprob = beam_logprobs_sum[q] + local_logprob - # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] - candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]}) - candidates = 
sorted(candidates, key=lambda x: -x['p']) - - new_state = [_.clone() for _ in state] - #beam_seq_prev, beam_seq_logprobs_prev - if t >= 1: - #we''ll need these as reference when we fork beams around - beam_seq_prev = beam_seq[:t].clone() - beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() - for vix in range(beam_size): - v = candidates[vix] - #fork beam index q into index vix - if t >= 1: - beam_seq[:t, vix] = beam_seq_prev[:, v['q']] - beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] - #rearrange recurrent states - for state_ix in range(len(new_state)): - # copy over state in previous beam q to new beam at vix - new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step - #append new end terminal at the end of this beam - beam_seq[t, vix] = v['c'] # c'th word is the continuation - beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here - beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam - state = new_state - return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates - - # Start diverse_beam_search - opt = kwargs['opt'] - temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs - beam_size = opt.get('beam_size', 10) - group_size = opt.get('group_size', 1) - diversity_lambda = opt.get('diversity_lambda', 0.5) - decoding_constraint = opt.get('decoding_constraint', 0) - remove_bad_endings = opt.get('remove_bad_endings', 0) - suppress_UNK = opt.get('suppress_UNK', 0) - length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) - bdash = beam_size // group_size # beam per group - - # INITIALIZATIONS - beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] - beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)] - beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] - - # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) - done_beams_table = [[] for _ in range(group_size)] - # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] - state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) - logprobs_table = list(init_logprobs.chunk(group_size, 0)) - # END INIT - - # Chunk elements in the args - args = list(args) - if self.__class__.__name__ == 'AttEnsemble': - args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name - args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name - else: - args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args] - args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] - - for t in range(self.seq_length + group_size - 1): - for divm in range(group_size): - if t >= divm and t <= self.seq_length + divm - 1: - # add diversity - logprobsf = logprobs_table[divm] - # suppress previous word - if decoding_constraint and t-divm > 0: - logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf')) - if remove_bad_endings and t-divm > 0: - logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf') - # suppress UNK tokens in the decoding - if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 
'UNK': - logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 - # diversity is added here - # the function directly modifies the logprobsf values and hence, we need to return - # the unaugmented ones for sorting the candidates in the end. # for historical - # reasons :-) - unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) - - # infer new beams - beam_seq_table[divm],\ - beam_seq_logprobs_table[divm],\ - beam_logprobs_sum_table[divm],\ - state_table[divm],\ - candidates_divm = beam_step(logprobsf, - unaug_logprobsf, - bdash, - t-divm, - beam_seq_table[divm], - beam_seq_logprobs_table[divm], - beam_logprobs_sum_table[divm], - state_table[divm]) - - # if time's up... or if end token is reached then copy beams - for vix in range(bdash): - if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1: - final_beam = { - 'seq': beam_seq_table[divm][:, vix].clone(), - 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), - 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), - 'p': beam_logprobs_sum_table[divm][vix].item() - } - final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) - done_beams_table[divm].append(final_beam) - # don't continue beams from finished sequences - beam_logprobs_sum_table[divm][vix] = -1000 - - # move the current group one step forward in time - - it = beam_seq_table[divm][t-divm].to(logprobsf.device) - logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) - logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) - - # all beams are sorted by their log-probabilities - done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] - done_beams = sum(done_beams_table, []) - return done_beams - - def sample_next_word(self, logprobs, sample_method, temperature): - if sample_method == 'greedy': - sampleLogprobs, it = torch.max(logprobs.data, 1) - it = it.view(-1).long() - elif sample_method == 'gumbel': # gumbel softmax - # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f - def sample_gumbel(shape, eps=1e-20): - U = torch.rand(shape).to(logprobs.device) - return -torch.log(-torch.log(U + eps) + eps) - def gumbel_softmax_sample(logits, temperature): - y = logits + sample_gumbel(logits.size()) - return F.log_softmax(y / temperature, dim=-1) - _logprobs = gumbel_softmax_sample(logprobs, temperature) - _, it = torch.max(_logprobs.data, 1) - sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions - else: - logprobs = logprobs / temperature - if sample_method.startswith('top'): # topk sampling - top_num = float(sample_method[3:]) - if 0 < top_num < 1: - # nucleus sampling from # The Curious Case of Neural Text Degeneration - probs = F.softmax(logprobs, dim=1) - sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) - _cumsum = sorted_probs.cumsum(1) - mask = _cumsum < top_num - mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) - sorted_probs = sorted_probs * mask.to(sorted_probs) - sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) - logprobs.scatter_(1, sorted_indices, sorted_probs.log()) - else: - the_k = int(top_num) - tmp = torch.empty_like(logprobs).fill_(float('-inf')) - topk, indices = torch.topk(logprobs, the_k, dim=1) - tmp = tmp.scatter(1, indices, topk) - logprobs = tmp - it = torch.distributions.Categorical(logits=logprobs.detach()).sample() - 
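A minimal, self-contained sketch of the nucleus (top-p) filtering step used in sample_next_word above, run on a toy five-token distribution; the tensor values and the 0.9 threshold are illustrative, not values taken from the repo.

import torch
import torch.nn.functional as F

logprobs = torch.log(torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]]))
top_p = 0.9  # corresponds to sample_method == 'top0.9'

probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
cumsum = sorted_probs.cumsum(1)
keep = cumsum < top_p
keep = torch.cat([torch.ones_like(keep[:, :1]), keep[:, :-1]], 1)  # always keep the top token
sorted_probs = sorted_probs * keep.to(sorted_probs)
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
filtered = torch.empty_like(logprobs)
filtered.scatter_(1, sorted_indices, sorted_probs.log())  # dropped tokens get -inf

it = torch.distributions.Categorical(logits=filtered).sample()
print(it)  # only tokens 0, 1 or 2 can be drawn; the 0.07 and 0.03 tokens are filtered out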
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions - return it, sampleLogprobs - - - def decode_sequence(self, seq): - return utils.decode_sequence(self.vocab, seq) diff --git a/captioning/models/FCModel.py b/captioning/models/FCModel.py deleted file mode 100644 index d3b8340c228e3f6039677e55540d87feb2765d62..0000000000000000000000000000000000000000 --- a/captioning/models/FCModel.py +++ /dev/null @@ -1,204 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import * -from . import utils - -from .CaptionModel import CaptionModel - -class LSTMCore(nn.Module): - def __init__(self, opt): - super(LSTMCore, self).__init__() - self.input_encoding_size = opt.input_encoding_size - self.rnn_size = opt.rnn_size - self.drop_prob_lm = opt.drop_prob_lm - - # Build a LSTM - self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) - self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) - self.dropout = nn.Dropout(self.drop_prob_lm) - - def forward(self, xt, state): - - all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) - sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) - sigmoid_chunk = torch.sigmoid(sigmoid_chunk) - in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) - forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) - out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) - - in_transform = torch.max(\ - all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), - all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) - next_c = forget_gate * state[1][-1] + in_gate * in_transform - next_h = out_gate * torch.tanh(next_c) - - output = self.dropout(next_h) - state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) - return output, state - -class FCModel(CaptionModel): - def __init__(self, opt): - super(FCModel, self).__init__() - self.vocab_size = opt.vocab_size - self.input_encoding_size = opt.input_encoding_size - self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.seq_length = opt.seq_length - self.fc_feat_size = opt.fc_feat_size - - self.ss_prob = 0.0 # Schedule sampling probability - - self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) - self.core = LSTMCore(opt) - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) - - self.init_weights() - - def init_weights(self): - initrange = 0.1 - self.embed.weight.data.uniform_(-initrange, initrange) - self.logit.bias.data.fill_(0) - self.logit.weight.data.uniform_(-initrange, initrange) - - def init_hidden(self, bsz): - weight = self.logit.weight - if self.rnn_type == 'lstm': - return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), - weight.new_zeros(self.num_layers, bsz, self.rnn_size)) - else: - return weight.new_zeros(self.num_layers, bsz, self.rnn_size) - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - batch_size = fc_feats.size(0) - seq_per_img = seq.shape[0] // batch_size - state = self.init_hidden(batch_size*seq_per_img) - outputs = [] - - if seq_per_img > 1: - fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) - - for i in range(seq.size(1) + 1): - if i == 0: - xt = self.img_embed(fc_feats) - else: - if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need 
to sample - sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) - sample_mask = sample_prob < self.ss_prob - if sample_mask.sum() == 0: - it = seq[:, i-1].clone() - else: - sample_ind = sample_mask.nonzero().view(-1) - it = seq[:, i-1].data.clone() - #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) - #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) - prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) - it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - else: - it = seq[:, i-1].clone() - # break if all the sequences end - if i >= 2 and seq[:, i-1].sum() == 0: - break - xt = self.embed(it) - - output, state = self.core(xt, state) - output = F.log_softmax(self.logit(output), dim=1) - outputs.append(output) - - return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() - - def get_logprobs_state(self, it, state): - # 'it' is contains a word index - xt = self.embed(it) - - output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output), dim=1) - - return logprobs, state - - def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - batch_size = fc_feats.size(0) - - assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' - seq = torch.LongTensor(self.seq_length, batch_size).zero_() - seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) - # lets process every image independently for now, for simplicity - - self.done_beams = [[] for _ in range(batch_size)] - for k in range(batch_size): - state = self.init_hidden(beam_size) - for t in range(2): - if t == 0: - xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) - elif t == 1: # input - it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(it) - - output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output), dim=1) - - self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) - seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score - seqLogprobs[:, k] = self.done_beams[k][0]['logps'] - # return the samples and their log likelihoods - return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - - def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): - sample_method = opt.get('sample_method', 'greedy') - beam_size = opt.get('beam_size', 1) - temperature = opt.get('temperature', 1.0) - if beam_size > 1 and sample_method in ['greedy', 'beam_search']: - return self._sample_beam(fc_feats, att_feats, opt) - - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size) - seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1) - for t in range(self.seq_length + 2): - if t == 0: - xt = self.img_embed(fc_feats) - else: - if t == 1: # input - it = fc_feats.data.new(batch_size).long().zero_() - xt = self.embed(it) - - output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output), dim=1) - - # sample the next_word - if t == self.seq_length + 1: # skip if we achieve maximum length - break - if sample_method == 'greedy': - sampleLogprobs, it = torch.max(logprobs.data, 1) - it = it.view(-1).long() - else: - if temperature == 
1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) - else: - # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() - it = torch.multinomial(prob_prev, 1).to(logprobs.device) - sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions - it = it.view(-1).long() # and flatten indices for downstream processing - - if t >= 1: - # stop when all finished - if t == 1: - unfinished = it > 0 - else: - unfinished = unfinished & (it > 0) - it = it * unfinished.type_as(it) - seq[:,t-1] = it #seq[t] the input of t+2 time step - seqLogprobs[:,t-1] = sampleLogprobs.view(-1) - if unfinished.sum() == 0: - break - - return seq, seqLogprobs diff --git a/captioning/models/M2Transformer.py b/captioning/models/M2Transformer.py deleted file mode 100644 index 0428e5d429645bf340a9d72a4b2d0ae6a14bb2bc..0000000000000000000000000000000000000000 --- a/captioning/models/M2Transformer.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226) - -pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git - -Note: -Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import copy -import math -import numpy as np - -from .CaptionModel import CaptionModel -from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel - -try: - from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory -except: - print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`') -from .TransformerModel import subsequent_mask, TransformerModel - - -class M2TransformerModel(TransformerModel): - - def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, - d_model=512, d_ff=2048, h=8, dropout=0.1): - "Helper: Construct a model from hyperparameters." - encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory, - attention_module_kwargs={'m': 40}) - # Another implementation is to use MultiLevelEncoder + att_embed - decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding; - model = Transformer(0, encoder, decoder) # 0 is bos - return model - - def __init__(self, opt): - super(M2TransformerModel, self).__init__(opt) - delattr(self, 'att_embed') - self.att_embed = lambda x: x # The visual embed is in the MAEncoder - # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5? 
- # Also the attention mask seems wrong in MAEncoder too...intersting - - def logit(self, x): # unsafe way - return x # M2transformer always output logsoftmax - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) - memory, att_masks = self.model.encoder(att_feats) - - return fc_feats[...,:0], att_feats[...,:0], memory, att_masks - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - if seq.ndim == 3: # B * seq_per_img * seq_len - seq = seq.reshape(-1, seq.shape[2]) - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) - - seq = seq.clone() - seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding) - outputs = self.model(att_feats, seq) - - return outputs - - def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): - """ - state = [ys.unsqueeze(0)] - """ - if len(state) == 0: - ys = it.unsqueeze(1) - else: - ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) - out = self.model.decoder(ys, memory, mask) - return out[:, -1], [ys.unsqueeze(0)] - - def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - group_size = opt.get('group_size', 1) - sample_n = opt.get('sample_n', 10) - assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' - - att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks) - seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0, - beam_size, return_probs=True, out_size=beam_size) - seq = seq.reshape(-1, *seq.shape[2:]) - seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:]) - - # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all(): - # import pudb;pu.db - # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1]) - return seq, seqLogprobs \ No newline at end of file diff --git a/captioning/models/ShowTellModel.py b/captioning/models/ShowTellModel.py deleted file mode 100644 index 2f3463b64f988aa61d90838ddcf8ac89053c3377..0000000000000000000000000000000000000000 --- a/captioning/models/ShowTellModel.py +++ /dev/null @@ -1,174 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import * -from . 
import utils - -from .CaptionModel import CaptionModel - -class ShowTellModel(CaptionModel): - def __init__(self, opt): - super(ShowTellModel, self).__init__() - self.vocab_size = opt.vocab_size - self.input_encoding_size = opt.input_encoding_size - self.rnn_type = opt.rnn_type - self.rnn_size = opt.rnn_size - self.num_layers = opt.num_layers - self.drop_prob_lm = opt.drop_prob_lm - self.seq_length = opt.seq_length - self.fc_feat_size = opt.fc_feat_size - - self.ss_prob = 0.0 # Schedule sampling probability - - self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) - self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) - self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) - self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) - self.dropout = nn.Dropout(self.drop_prob_lm) - - self.init_weights() - - def init_weights(self): - initrange = 0.1 - self.embed.weight.data.uniform_(-initrange, initrange) - self.logit.bias.data.fill_(0) - self.logit.weight.data.uniform_(-initrange, initrange) - - def init_hidden(self, bsz): - weight = self.logit.weight - if self.rnn_type == 'lstm': - return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), - weight.new_zeros(self.num_layers, bsz, self.rnn_size)) - else: - return weight.new_zeros(self.num_layers, bsz, self.rnn_size) - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - batch_size = fc_feats.size(0) - seq_per_img = seq.shape[0] // batch_size - state = self.init_hidden(batch_size*seq_per_img) - outputs = [] - - if seq_per_img > 1: - fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) - - for i in range(seq.size(1) + 1): - if i == 0: - xt = self.img_embed(fc_feats) - else: - if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample - sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) - sample_mask = sample_prob < self.ss_prob - if sample_mask.sum() == 0: - it = seq[:, i-1].clone() - else: - sample_ind = sample_mask.nonzero().view(-1) - it = seq[:, i-1].data.clone() - #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) - #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) - prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) - it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - else: - it = seq[:, i-1].clone() - # break if all the sequences end - if i >= 2 and seq[:, i-1].data.sum() == 0: - break - xt = self.embed(it) - - output, state = self.core(xt.unsqueeze(0), state) - output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) - outputs.append(output) - - return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() - - def get_logprobs_state(self, it, state): - # 'it' contains a word index - xt = self.embed(it) - - output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) - - return logprobs, state - - def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): - beam_size = opt.get('beam_size', 10) - batch_size = fc_feats.size(0) - - assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' - seq = torch.LongTensor(self.seq_length, batch_size).zero_() - seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) - # lets process every image independently for now, for simplicity - - self.done_beams = [[] for _ in range(batch_size)] - for k in range(batch_size): - state = self.init_hidden(beam_size) - for t in range(2): - if t == 0: - xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) - elif t == 1: # input - it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(it) - - output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) - - self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) - seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score - seqLogprobs[:, k] = self.done_beams[k][0]['logps'] - # return the samples and their log likelihoods - return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - - def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): - sample_method = opt.get('sample_method', 'greedy') - beam_size = opt.get('beam_size', 1) - temperature = opt.get('temperature', 1.0) - if beam_size > 1 and sample_method in ['greedy', 'beam_search']: - return self.sample_beam(fc_feats, att_feats, opt) - - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size) - seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) - for t in range(self.seq_length + 2): - if t == 0: - xt = self.img_embed(fc_feats) - else: - if t == 1: # input - it = fc_feats.data.new(batch_size).long().zero_() - xt = self.embed(it) - - output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) - - # sample the next word - if t == self.seq_length + 1: # skip if we achieve maximum length - break - if sample_method == 'greedy': - sampleLogprobs, it = torch.max(logprobs.data, 1) - it = it.view(-1).long() - else: - if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) - else: - # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() - it = torch.multinomial(prob_prev, 1).to(logprobs.device) - sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions - it = it.view(-1).long() # and flatten indices for downstream processing - - if t >= 1: - # stop when all finished - if t == 1: - unfinished = it > 0 - else: - unfinished = unfinished & (it > 0) - it = it * unfinished.type_as(it) - seq[:,t-1] = it #seq[t] the input of t+2 time step - seqLogprobs[:,t-1] = sampleLogprobs.view(-1) - if unfinished.sum() == 0: - break - - return seq, seqLogprobs \ No newline at end of file diff --git a/captioning/models/TransformerModel.py b/captioning/models/TransformerModel.py deleted file mode 100644 index 70a27a25e968cf906bdde461e054fed77c08f70b..0000000000000000000000000000000000000000 --- a/captioning/models/TransformerModel.py +++ /dev/null @@ -1,363 +0,0 @@ -# This file contains Transformer network -# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html - -# The cfg name correspondance: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size -# h is always 8 - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as 
nn -import torch.nn.functional as F -from . import utils - -import copy -import math -import numpy as np - -from .CaptionModel import CaptionModel -from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel - -class EncoderDecoder(nn.Module): - """ - A standard Encoder-Decoder architecture. Base for this and many - other models. - """ - def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): - super(EncoderDecoder, self).__init__() - self.encoder = encoder - self.decoder = decoder - self.src_embed = src_embed - self.tgt_embed = tgt_embed - self.generator = generator - - def forward(self, src, tgt, src_mask, tgt_mask): - "Take in and process masked src and target sequences." - return self.decode(self.encode(src, src_mask), src_mask, - tgt, tgt_mask) - - def encode(self, src, src_mask): - return self.encoder(self.src_embed(src), src_mask) - - def decode(self, memory, src_mask, tgt, tgt_mask): - return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) - -class Generator(nn.Module): - "Define standard linear + softmax generation step." - def __init__(self, d_model, vocab): - super(Generator, self).__init__() - self.proj = nn.Linear(d_model, vocab) - - def forward(self, x): - return F.log_softmax(self.proj(x), dim=-1) - -def clones(module, N): - "Produce N identical layers." - return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) - -class Encoder(nn.Module): - "Core encoder is a stack of N layers" - def __init__(self, layer, N): - super(Encoder, self).__init__() - self.layers = clones(layer, N) - self.norm = LayerNorm(layer.size) - - def forward(self, x, mask): - "Pass the input (and mask) through each layer in turn." - for layer in self.layers: - x = layer(x, mask) - return self.norm(x) - -class LayerNorm(nn.Module): - "Construct a layernorm module (See citation for details)." - def __init__(self, features, eps=1e-6): - super(LayerNorm, self).__init__() - self.a_2 = nn.Parameter(torch.ones(features)) - self.b_2 = nn.Parameter(torch.zeros(features)) - self.eps = eps - - def forward(self, x): - mean = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) - return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 - -class SublayerConnection(nn.Module): - """ - A residual connection followed by a layer norm. - Note for code simplicity the norm is first as opposed to last. - """ - def __init__(self, size, dropout): - super(SublayerConnection, self).__init__() - self.norm = LayerNorm(size) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, sublayer): - "Apply residual connection to any sublayer with the same size." - return x + self.dropout(sublayer(self.norm(x))) - -class EncoderLayer(nn.Module): - "Encoder is made up of self-attn and feed forward (defined below)" - def __init__(self, size, self_attn, feed_forward, dropout): - super(EncoderLayer, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.sublayer = clones(SublayerConnection(size, dropout), 2) - self.size = size - - def forward(self, x, mask): - "Follow Figure 1 (left) for connections." - x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) - return self.sublayer[1](x, self.feed_forward) - -class Decoder(nn.Module): - "Generic N layer decoder with masking." 
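A small sketch of the pre-norm residual pattern that SublayerConnection above implements (normalize first, apply the sublayer, then add the result back to the input); nn.LayerNorm and nn.Linear stand in for the file's own LayerNorm and attention/FFN modules, so the shapes and modules are illustrative only.

import torch
import torch.nn as nn

d_model = 512
norm = nn.LayerNorm(d_model)            # stand-in for the LayerNorm class above
dropout = nn.Dropout(0.1)
sublayer = nn.Linear(d_model, d_model)  # stand-in for self-attention or the FFN

x = torch.randn(2, 10, d_model)         # (batch, seq_len, d_model)
out = x + dropout(sublayer(norm(x)))    # "norm is first as opposed to last"
print(out.shape)                        # torch.Size([2, 10, 512])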
- def __init__(self, layer, N): - super(Decoder, self).__init__() - self.layers = clones(layer, N) - self.norm = LayerNorm(layer.size) - - def forward(self, x, memory, src_mask, tgt_mask): - for layer in self.layers: - x = layer(x, memory, src_mask, tgt_mask) - return self.norm(x) - -class DecoderLayer(nn.Module): - "Decoder is made of self-attn, src-attn, and feed forward (defined below)" - def __init__(self, size, self_attn, src_attn, feed_forward, dropout): - super(DecoderLayer, self).__init__() - self.size = size - self.self_attn = self_attn - self.src_attn = src_attn - self.feed_forward = feed_forward - self.sublayer = clones(SublayerConnection(size, dropout), 3) - - def forward(self, x, memory, src_mask, tgt_mask): - "Follow Figure 1 (right) for connections." - m = memory - x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) - x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) - return self.sublayer[2](x, self.feed_forward) - -def subsequent_mask(size): - "Mask out subsequent positions." - attn_shape = (1, size, size) - subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') - return torch.from_numpy(subsequent_mask) == 0 - -def attention(query, key, value, mask=None, dropout=None): - "Compute 'Scaled Dot Product Attention'" - d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) \ - / math.sqrt(d_k) - if mask is not None: - scores = scores.masked_fill(mask == 0, float('-inf')) - p_attn = F.softmax(scores, dim = -1) - if dropout is not None: - p_attn = dropout(p_attn) - return torch.matmul(p_attn, value), p_attn - -class MultiHeadedAttention(nn.Module): - def __init__(self, h, d_model, dropout=0.1): - "Take in model size and number of heads." - super(MultiHeadedAttention, self).__init__() - assert d_model % h == 0 - # We assume d_v always equals d_k - self.d_k = d_model // h - self.h = h - self.linears = clones(nn.Linear(d_model, d_model), 4) - self.attn = None - self.dropout = nn.Dropout(p=dropout) - - def forward(self, query, key, value, mask=None): - "Implements Figure 2" - if mask is not None: - # Same mask applied to all h heads. - mask = mask.unsqueeze(1) - nbatches = query.size(0) - - # 1) Do all the linear projections in batch from d_model => h x d_k - query, key, value = \ - [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - for l, x in zip(self.linears, (query, key, value))] - - # 2) Apply attention on all the projected vectors in batch. - x, self.attn = attention(query, key, value, mask=mask, - dropout=self.dropout) - - # 3) "Concat" using a view and apply a final linear. - x = x.transpose(1, 2).contiguous() \ - .view(nbatches, -1, self.h * self.d_k) - return self.linears[-1](x) - -class PositionwiseFeedForward(nn.Module): - "Implements FFN equation." - def __init__(self, d_model, d_ff, dropout=0.1): - super(PositionwiseFeedForward, self).__init__() - self.w_1 = nn.Linear(d_model, d_ff) - self.w_2 = nn.Linear(d_ff, d_model) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - return self.w_2(self.dropout(F.relu(self.w_1(x)))) - -class Embeddings(nn.Module): - def __init__(self, d_model, vocab): - super(Embeddings, self).__init__() - self.lut = nn.Embedding(vocab, d_model) - self.d_model = d_model - - def forward(self, x): - return self.lut(x) * math.sqrt(self.d_model) - -class PositionalEncoding(nn.Module): - "Implement the PE function." 
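A quick illustration of what subsequent_mask above produces: for size=4 the result is a lower-triangular boolean mask, so decoder position t can only attend to positions <= t. The snippet simply re-runs the same function as defined above on a tiny size.

import numpy as np
import torch

def subsequent_mask(size):
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0

print(subsequent_mask(4)[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)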
- def __init__(self, d_model, dropout, max_len=5000): - super(PositionalEncoding, self).__init__() - self.dropout = nn.Dropout(p=dropout) - - # Compute the positional encodings once in log space. - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len).unsqueeze(1).float() - div_term = torch.exp(torch.arange(0, d_model, 2).float() * - -(math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) - - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return self.dropout(x) - -class TransformerModel(AttModel): - - def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, - d_model=512, d_ff=2048, h=8, dropout=0.1): - "Helper: Construct a model from hyperparameters." - c = copy.deepcopy - attn = MultiHeadedAttention(h, d_model, dropout) - ff = PositionwiseFeedForward(d_model, d_ff, dropout) - position = PositionalEncoding(d_model, dropout) - model = EncoderDecoder( - Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc), - Decoder(DecoderLayer(d_model, c(attn), c(attn), - c(ff), dropout), N_dec), - lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), - nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), - Generator(d_model, tgt_vocab)) - - # This was important from their code. - # Initialize parameters with Glorot / fan_avg. - for p in model.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - return model - - def __init__(self, opt): - super(TransformerModel, self).__init__(opt) - self.opt = opt - # self.config = yaml.load(open(opt.config_file)) - - self.N_enc = getattr(opt, 'N_enc', opt.num_layers) - self.N_dec = getattr(opt, 'N_dec', opt.num_layers) - self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) - self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) - self.h = getattr(opt, 'num_att_heads', 8) - self.dropout = getattr(opt, 'dropout', 0.1) - - delattr(self, 'att_embed') - self.att_embed = nn.Sequential(*( - ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ - (nn.Linear(self.att_feat_size, self.d_model), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm))+ - ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) - - delattr(self, 'embed') - self.embed = lambda x : x - delattr(self, 'fc_embed') - self.fc_embed = lambda x : x - delattr(self, 'logit') - del self.ctx2att - - tgt_vocab = self.vocab_size + 1 - - - self.model = self.make_model(0, tgt_vocab, - N_enc=self.N_enc, - N_dec=self.N_dec, - d_model=self.d_model, - d_ff=self.d_ff, - h=self.h, - dropout=self.dropout) - - def logit(self, x): # unsafe way - return self.model.generator.proj(x) - - def init_hidden(self, bsz): - return [] - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) - memory = self.model.encode(att_feats, att_masks) - - return fc_feats[...,:0], att_feats[...,:0], memory, att_masks - - def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): - att_feats, att_masks = self.clip_att(att_feats, att_masks) - - att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) - - if att_masks is None: - att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) - att_masks = att_masks.unsqueeze(-2) - - if seq is not None: - # crop the last one - # seq = seq[:,:-1] - seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) - seq_mask[:,0] = 1 # bos - - seq_mask = 
seq_mask.unsqueeze(-2) - seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) - - seq_per_img = seq.shape[0] // att_feats.shape[0] - if seq_per_img > 1: - att_feats, att_masks = utils.repeat_tensors(seq_per_img, - [att_feats, att_masks] - ) - else: - seq_mask = None - - return att_feats, seq, att_masks, seq_mask - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - if seq.ndim == 3: # B * seq_per_img * seq_len - seq = seq.reshape(-1, seq.shape[2]) - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) - - out = self.model(att_feats, seq, att_masks, seq_mask) - - outputs = self.model.generator(out) - return outputs - # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) - - def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): - """ - state = [ys.unsqueeze(0)] - """ - if len(state) == 0: - ys = it.unsqueeze(1) - else: - ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) - out = self.model.decode(memory, mask, - ys, - subsequent_mask(ys.size(1)) - .to(memory.device)) - return out[:, -1], [ys.unsqueeze(0)] \ No newline at end of file diff --git a/captioning/models/__init__.py b/captioning/models/__init__.py deleted file mode 100644 index 29f7a9cb48b9397ed0b658c15580b43c5ae1300d..0000000000000000000000000000000000000000 --- a/captioning/models/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import copy - -import numpy as np -import torch - -from .ShowTellModel import ShowTellModel -from .FCModel import FCModel -from .AttModel import * -from .TransformerModel import TransformerModel -from .cachedTransformer import TransformerModel as cachedTransformer -from .BertCapModel import BertCapModel -from .M2Transformer import M2TransformerModel -from .AoAModel import AoAModel - -def setup(opt): - if opt.caption_model in ['fc', 'show_tell']: - print('Warning: %s model is mostly deprecated; many new features are not supported.' 
%opt.caption_model) - if opt.caption_model == 'fc': - print('Use newfc instead of fc') - if opt.caption_model == 'fc': - model = FCModel(opt) - elif opt.caption_model == 'language_model': - model = LMModel(opt) - elif opt.caption_model == 'newfc': - model = NewFCModel(opt) - elif opt.caption_model == 'show_tell': - model = ShowTellModel(opt) - # Att2in model in self-critical - elif opt.caption_model == 'att2in': - model = Att2inModel(opt) - # Att2in model with two-layer MLP img embedding and word embedding - elif opt.caption_model == 'att2in2': - model = Att2in2Model(opt) - elif opt.caption_model == 'att2all2': - print('Warning: this is not a correct implementation of the att2all model in the original paper.') - model = Att2all2Model(opt) - # Adaptive Attention model from Knowing when to look - elif opt.caption_model == 'adaatt': - model = AdaAttModel(opt) - # Adaptive Attention with maxout lstm - elif opt.caption_model == 'adaattmo': - model = AdaAttMOModel(opt) - # Top-down attention model - elif opt.caption_model in ['topdown', 'updown']: - model = UpDownModel(opt) - # StackAtt - elif opt.caption_model == 'stackatt': - model = StackAttModel(opt) - # DenseAtt - elif opt.caption_model == 'denseatt': - model = DenseAttModel(opt) - # Transformer - elif opt.caption_model == 'transformer': - if getattr(opt, 'cached_transformer', False): - model = cachedTransformer(opt) - else: - model = TransformerModel(opt) - # AoANet - elif opt.caption_model == 'aoa': - model = AoAModel(opt) - elif opt.caption_model == 'bert': - model = BertCapModel(opt) - elif opt.caption_model == 'm2transformer': - model = M2TransformerModel(opt) - else: - raise Exception("Caption model not supported: {}".format(opt.caption_model)) - - return model diff --git a/captioning/models/cachedTransformer.py b/captioning/models/cachedTransformer.py deleted file mode 100644 index 719701cb348a11255a36d554ad350dcfc87e5121..0000000000000000000000000000000000000000 --- a/captioning/models/cachedTransformer.py +++ /dev/null @@ -1,420 +0,0 @@ -# This file contains Transformer network -# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html - -# The cfg name correspondance: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size -# h is always 8 - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -from . import utils - -import copy -import math -import numpy as np - -from .CaptionModel import CaptionModel -from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel - -class EncoderDecoder(nn.Module): - """ - A standard Encoder-Decoder architecture. Base for this and many - other models. - """ - def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): - super(EncoderDecoder, self).__init__() - self.encoder = encoder - self.decoder = decoder - self.src_embed = src_embed - self.tgt_embed = tgt_embed - self.generator = generator - - def forward(self, src, tgt, src_mask, tgt_mask): - "Take in and process masked src and target sequences." 
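The header comment of this file (and of TransformerModel.py) maps the repo's option names onto the standard Transformer hyperparameters; a small sketch of that correspondence, with the make_model defaults quoted as illustrative values.

# Option name in this repo  ->  Transformer hyperparameter (make_model default)
cfg_to_transformer = {
    'num_layers':          'N_enc / N_dec, the encoder/decoder depth (default 6)',
    'input_encoding_size': 'd_model, the model width (default 512)',
    'rnn_size':            'd_ff, the feed-forward width (default 2048)',
    'num_att_heads':       'h, the number of attention heads (default 8)',
}
for opt_name, meaning in cfg_to_transformer.items():
    print(f'{opt_name:22s} -> {meaning}')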
- return self.decode(self.encode(src, src_mask), src_mask, - tgt, tgt_mask) - - def encode(self, src, src_mask): - return self.encoder(self.src_embed(src), src_mask) - - def decode(self, memory, src_mask, tgt, tgt_mask, past=None): - return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past) - -class Generator(nn.Module): - "Define standard linear + softmax generation step." - def __init__(self, d_model, vocab): - super(Generator, self).__init__() - self.proj = nn.Linear(d_model, vocab) - - def forward(self, x): - return F.log_softmax(self.proj(x), dim=-1) - -def clones(module, N): - "Produce N identical layers." - return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) - -class Encoder(nn.Module): - "Core encoder is a stack of N layers" - def __init__(self, layer, N): - super(Encoder, self).__init__() - self.layers = clones(layer, N) - self.norm = LayerNorm(layer.size) - - def forward(self, x, mask): - "Pass the input (and mask) through each layer in turn." - for layer in self.layers: - x = layer(x, mask) - return self.norm(x) - -class LayerNorm(nn.Module): - "Construct a layernorm module (See citation for details)." - def __init__(self, features, eps=1e-6): - super(LayerNorm, self).__init__() - self.a_2 = nn.Parameter(torch.ones(features)) - self.b_2 = nn.Parameter(torch.zeros(features)) - self.eps = eps - - def forward(self, x): - mean = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) - return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 - -class SublayerConnection(nn.Module): - """ - A residual connection followed by a layer norm. - Note for code simplicity the norm is first as opposed to last. - """ - def __init__(self, size, dropout): - super(SublayerConnection, self).__init__() - self.norm = LayerNorm(size) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, sublayer): - "Apply residual connection to any sublayer with the same size." - _x = sublayer(self.norm(x)) - if type(_x) is tuple: # for multi-head attention that returns past - return x + self.dropout(_x[0]), _x[1] - return x + self.dropout(_x) - -class EncoderLayer(nn.Module): - "Encoder is made up of self-attn and feed forward (defined below)" - def __init__(self, size, self_attn, feed_forward, dropout): - super(EncoderLayer, self).__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.sublayer = clones(SublayerConnection(size, dropout), 2) - self.size = size - - def forward(self, x, mask): - "Follow Figure 1 (left) for connections." - x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) - return self.sublayer[1](x, self.feed_forward) - -class Decoder(nn.Module): - "Generic N layer decoder with masking." 
- def __init__(self, layer, N): - super(Decoder, self).__init__() - self.layers = clones(layer, N) - self.norm = LayerNorm(layer.size) - - def forward(self, x, memory, src_mask, tgt_mask, past=None): - if past is not None: - present = [[], []] - x = x[:, -1:] - tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None - past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) - else: - past = [None] * len(self.layers) - for i, (layer, layer_past) in enumerate(zip(self.layers, past)): - x = layer(x, memory, src_mask, tgt_mask, - layer_past) - if layer_past is not None: - present[0].append(x[1][0]) - present[1].append(x[1][1]) - x = x[0] - if past[0] is None: - return self.norm(x) - else: - return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] - - -class DecoderLayer(nn.Module): - "Decoder is made of self-attn, src-attn, and feed forward (defined below)" - def __init__(self, size, self_attn, src_attn, feed_forward, dropout): - super(DecoderLayer, self).__init__() - self.size = size - self.self_attn = self_attn - self.src_attn = src_attn - self.feed_forward = feed_forward - self.sublayer = clones(SublayerConnection(size, dropout), 3) - - def forward(self, x, memory, src_mask, tgt_mask, layer_past=None): - "Follow Figure 1 (right) for connections." - m = memory - if layer_past is None: - x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) - x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) - return self.sublayer[2](x, self.feed_forward) - else: - present = [None, None] - x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0])) - x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1])) - return self.sublayer[2](x, self.feed_forward), present - -def subsequent_mask(size): - "Mask out subsequent positions." - attn_shape = (1, size, size) - subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') - return torch.from_numpy(subsequent_mask) == 0 - -def attention(query, key, value, mask=None, dropout=None): - "Compute 'Scaled Dot Product Attention'" - d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) \ - / math.sqrt(d_k) - if mask is not None: - scores = scores.masked_fill(mask == 0, float('-inf')) - p_attn = F.softmax(scores, dim = -1) - if dropout is not None: - p_attn = dropout(p_attn) - return torch.matmul(p_attn, value), p_attn - -class MultiHeadedAttention(nn.Module): - def __init__(self, h, d_model, dropout=0.1): - "Take in model size and number of heads." - super(MultiHeadedAttention, self).__init__() - assert d_model % h == 0 - # We assume d_v always equals d_k - self.d_k = d_model // h - self.h = h - self.linears = clones(nn.Linear(d_model, d_model), 4) - self.attn = None - self.dropout = nn.Dropout(p=dropout) - - def forward(self, query, key, value, mask=None, layer_past=None): - "Implements Figure 2" - if mask is not None: - # Same mask applied to all h heads. - mask = mask.unsqueeze(1) - nbatches = query.size(0) - - # The past works differently here. For self attn, the query and key be updated incrementailly - # For src_attn the past is fixed. 
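A toy sketch of the self-attention key/value caching idea the comments above describe: at each decoding step only the newest token is projected, its key and value are appended to the cache, and the query attends over everything cached so far. Single head, no mask needed because the query is only the newest position; shapes and layer names are illustrative, not the module's real interface.

import torch
import torch.nn as nn

d_model, steps = 8, 5
Wq, Wk, Wv = (nn.Linear(d_model, d_model) for _ in range(3))

past_k = torch.zeros(1, 0, d_model)  # cache starts empty and grows one row per step
past_v = torch.zeros(1, 0, d_model)
for t in range(steps):
    x_t = torch.randn(1, 1, d_model)              # embedding of the newest token only
    q, k, v = Wq(x_t), Wk(x_t), Wv(x_t)
    past_k = torch.cat([past_k, k], dim=1)        # incremental update, as in the self-attn branch above
    past_v = torch.cat([past_v, v], dim=1)
    scores = q @ past_k.transpose(-2, -1) / d_model ** 0.5
    out = torch.softmax(scores, dim=-1) @ past_v  # (1, 1, d_model)
print(past_k.shape)  # torch.Size([1, 5, 8])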
- - # For src_attn, when the layer past is ready - if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1 - query = self.linears[0](query) - key, value = layer_past[0], layer_past[1] - present = torch.stack([key, value]) - else: - # 1) Do all the linear projections in batch from d_model => h x d_k - query, key, value = \ - [l(x) for l, x in zip(self.linears, (query, key, value))] - - # self attn + past OR the first time step of src attn - if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): - past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - present = torch.stack([key, value]) - - query, key, value = \ - [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) - for x in [query, key, value]] - - # 2) Apply attention on all the projected vectors in batch. - x, self.attn = attention(query, key, value, mask=mask, - dropout=self.dropout) - - # 3) "Concat" using a view and apply a final linear. - x = x.transpose(1, 2).contiguous() \ - .view(nbatches, -1, self.h * self.d_k) - if layer_past is not None: - return self.linears[-1](x), present - else: - return self.linears[-1](x) - -class PositionwiseFeedForward(nn.Module): - "Implements FFN equation." - def __init__(self, d_model, d_ff, dropout=0.1): - super(PositionwiseFeedForward, self).__init__() - self.w_1 = nn.Linear(d_model, d_ff) - self.w_2 = nn.Linear(d_ff, d_model) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - return self.w_2(self.dropout(F.relu(self.w_1(x)))) - -class Embeddings(nn.Module): - def __init__(self, d_model, vocab): - super(Embeddings, self).__init__() - self.lut = nn.Embedding(vocab, d_model) - self.d_model = d_model - - def forward(self, x): - return self.lut(x) * math.sqrt(self.d_model) - -class PositionalEncoding(nn.Module): - "Implement the PE function." - def __init__(self, d_model, dropout, max_len=5000): - super(PositionalEncoding, self).__init__() - self.dropout = nn.Dropout(p=dropout) - - # Compute the positional encodings once in log space. - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len).unsqueeze(1).float() - div_term = torch.exp(torch.arange(0, d_model, 2).float() * - -(math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) - - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return self.dropout(x) - -class TransformerModel(AttModel): - - def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, - d_model=512, d_ff=2048, h=8, dropout=0.1): - "Helper: Construct a model from hyperparameters." - c = copy.deepcopy - attn = MultiHeadedAttention(h, d_model, dropout) - ff = PositionwiseFeedForward(d_model, d_ff, dropout) - position = PositionalEncoding(d_model, dropout) - model = EncoderDecoder( - Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc), - Decoder(DecoderLayer(d_model, c(attn), c(attn), - c(ff), dropout), N_dec), - lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), - nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), - Generator(d_model, tgt_vocab)) - - # This was important from their code. - # Initialize parameters with Glorot / fan_avg. 
- for p in model.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - return model - - def __init__(self, opt): - super(TransformerModel, self).__init__(opt) - self.opt = opt - # self.config = yaml.load(open(opt.config_file)) - - self.N_enc = getattr(opt, 'N_enc', opt.num_layers) - self.N_dec = getattr(opt, 'N_dec', opt.num_layers) - self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) - self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) - self.h = getattr(opt, 'num_att_heads', 8) - self.dropout = getattr(opt, 'dropout', 0.1) - - delattr(self, 'att_embed') - self.att_embed = nn.Sequential(*( - ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ - (nn.Linear(self.att_feat_size, self.d_model), - nn.ReLU(), - nn.Dropout(self.drop_prob_lm))+ - ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) - - delattr(self, 'embed') - self.embed = lambda x : x - delattr(self, 'fc_embed') - self.fc_embed = lambda x : x - delattr(self, 'logit') - del self.ctx2att - - tgt_vocab = self.vocab_size + 1 - - - self.model = self.make_model(0, tgt_vocab, - N_enc=self.N_enc, - N_dec=self.N_dec, - d_model=self.d_model, - d_ff=self.d_ff, - h=self.h, - dropout=self.dropout) - - def logit(self, x): # unsafe way - return self.model.generator.proj(x) - - def init_hidden(self, bsz): - return [] - - def _prepare_feature(self, fc_feats, att_feats, att_masks): - - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) - memory = self.model.encode(att_feats, att_masks) - - return fc_feats[...,:0], att_feats[...,:0], memory, att_masks - - def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): - att_feats, att_masks = self.clip_att(att_feats, att_masks) - - att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) - - if att_masks is None: - att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) - att_masks = att_masks.unsqueeze(-2) - - if seq is not None: - # crop the last one - # seq = seq[:,:-1] - seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) - seq_mask[:,0] = 1 # bos - - seq_mask = seq_mask.unsqueeze(-2) - seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) - - seq_per_img = seq.shape[0] // att_feats.shape[0] - if seq_per_img > 1: - att_feats, att_masks = utils.repeat_tensors(seq_per_img, - [att_feats, att_masks] - ) - else: - seq_mask = None - - return att_feats, seq, att_masks, seq_mask - - def _forward(self, fc_feats, att_feats, seq, att_masks=None): - if seq.ndim == 3: # B * seq_per_img * seq_len - seq = seq.reshape(-1, seq.shape[2]) - att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) - - out = self.model(att_feats, seq, att_masks, seq_mask) - - outputs = self.model.generator(out) - return outputs - # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) - - def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): - """ - state is the precomputed key/value. 
N_dec x seq_len x d_model - Note: due to the layer norm, it's not equivalant to stateless, - but it seems behaving similar - """ - # state is tokens + past - if len(state) == 0: - ys = it.unsqueeze(1) - # basically empty state, just to let it know to return past - # The second dim has to be batch_size, for beam search purpose - past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self - fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src - # 2 for self attn, 2 for src attn - else: - ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) - past = state[1:] - out, past = self.model.decode(memory, mask, - ys, # We still feed the full past words, because we need it for position embedding to know the position id - subsequent_mask(ys.size(1)) - .to(memory.device), - past=past) - return out[:, -1], [ys.unsqueeze(0)] + past diff --git a/captioning/models/utils.py b/captioning/models/utils.py deleted file mode 100644 index feb130bceb26aae56b9a849a7131f8fde784a43d..0000000000000000000000000000000000000000 --- a/captioning/models/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch - -def repeat_tensors(n, x): - """ - For a tensor of size Bx..., we repeat it n times, and make it Bnx... - For collections, do nested repeat - """ - if torch.is_tensor(x): - x = x.unsqueeze(1) # Bx1x... - x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... - x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... - elif type(x) is list or type(x) is tuple: - x = [repeat_tensors(n, _) for _ in x] - return x - - -def split_tensors(n, x): - if torch.is_tensor(x): - assert x.shape[0] % n == 0 - x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) - elif type(x) is list or type(x) is tuple: - x = [split_tensors(n, _) for _ in x] - elif x is None: - x = [None] * n - return x \ No newline at end of file diff --git a/captioning/modules/loss_wrapper.py b/captioning/modules/loss_wrapper.py deleted file mode 100644 index d86f1e6f7df4a6bc112563294b8bf6bb4d999b98..0000000000000000000000000000000000000000 --- a/captioning/modules/loss_wrapper.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -from . 
import losses -from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward -from ..utils.clipscore import CLIPScore -import numpy as np - -class LossWrapper(torch.nn.Module): - def __init__(self, model, opt): - super(LossWrapper, self).__init__() - self.opt = opt - self.model = model - if opt.label_smoothing > 0: - self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing) - else: - self.crit = losses.LanguageModelCriterion() - self.rl_crit = losses.RewardCriterion() - self.struc_crit = losses.StructureLosses(opt) - - self.clipscore_model = None - if self.opt.use_clipscore: - use_grammar = getattr(self.opt, 'use_grammar', False) - joint_out = getattr(self.opt, 'joint_out', False) - self.clipscore_model = CLIPScore( - mode=opt.clipscore_mode, - use_grammar=use_grammar, - joint_out=joint_out, - ) - for p in self.clipscore_model.parameters(): - p.requires_grad = False - - if use_grammar: - state_dict = torch.load(self.opt.clip_load_path, map_location='cpu') - self.clipscore_model.load_state_dict(state_dict['state_dict']) - - def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, - sc_flag, struc_flag, clip_vis_feats=None): - opt = self.opt - - out = {} - if struc_flag: - if opt.structure_loss_weight < 1: - lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) - else: - lm_loss = torch.tensor(0).type_as(fc_feats) - if opt.structure_loss_weight > 0: - gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, - opt={'sample_method':opt.train_sample_method, - 'beam_size':opt.train_beam_size, - 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\ - or not 'margin' in opt.structure_loss_type, - 'sample_n': opt.train_sample_n}, - mode='sample') - gts = [gts[_] for _ in gt_indices.tolist()] - struc_loss = self.struc_crit(sample_logprobs, gen_result, gts) - else: - struc_loss = {'loss': torch.tensor(0).type_as(fc_feats), - 'reward': torch.tensor(0).type_as(fc_feats)} - loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss'] - out['lm_loss'] = lm_loss - out['struc_loss'] = struc_loss['loss'] - out['reward'] = struc_loss['reward'] - elif not sc_flag: - loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) - else: - self.model.eval() - with torch.no_grad(): - greedy_res, _ = self.model(fc_feats, att_feats, att_masks, - mode='sample', - opt={'sample_method': opt.sc_sample_method, - 'beam_size': opt.sc_beam_size}) - self.model.train() - gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, - opt={'sample_method':opt.train_sample_method, - 'beam_size':opt.train_beam_size, - 'sample_n': opt.train_sample_n}, - mode='sample') - gts = [gts[_] for _ in gt_indices.tolist()] - - if getattr(self.opt, 'use_multi_rewards', False): - assert self.opt.use_clipscore - clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward( - greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) - - if self.opt.clipscore_mode == 'clip_s': - out['CLIP-S'] = clipscore_unnormalized_mean - elif self.opt.clipscore_mode == 'refclip_s': - out['RefCLIP-S'] = clipscore_unnormalized_mean - - if getattr(self.opt, 'use_grammar', False): - out['grammar_reward'] = grammar_rewards.mean() - - reward = clipscore_reward_normalized + grammar_rewards - - - 
else: - assert grammar_rewards is None - - cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( - greedy_res, gts, gen_result, self.opt) - out['CIDEr'] = cider_unnormalized_mean - if isinstance(cider_reward_normalized, np.ndarray): - cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device) - - reward = clipscore_reward_normalized + cider_reward_normalized - else: - if self.opt.use_clipscore: - clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward( - greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) - if self.opt.clipscore_mode == 'clip_s': - out['CLIP-S'] = clipscore_unnormalized_mean - elif self.opt.clipscore_mode == 'refclip_s': - out['RefCLIP-S'] = clipscore_unnormalized_mean - reward = clipscore_reward_normalized - else: - cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( - greedy_res, gts, gen_result, self.opt) - out['CIDEr'] = cider_unnormalized_mean - reward = cider_reward_normalized - - if isinstance(reward, np.ndarray): - reward = torch.from_numpy(reward) - reward = reward.to(sample_logprobs) - loss = self.rl_crit(sample_logprobs, gen_result.data, reward) - out['reward'] = reward[:,0].mean() - out['loss'] = loss - return out - diff --git a/captioning/modules/losses.py b/captioning/modules/losses.py deleted file mode 100644 index 28d6db59dd70a9418a8a074d54402d6b5823520c..0000000000000000000000000000000000000000 --- a/captioning/modules/losses.py +++ /dev/null @@ -1,218 +0,0 @@ -import torch -import torch.nn as nn -from ..utils.rewards import get_scores, get_self_cider_scores - -class RewardCriterion(nn.Module): - def __init__(self): - super(RewardCriterion, self).__init__() - - def forward(self, input, seq, reward): - input = input.gather(2, seq.unsqueeze(2)).squeeze(2) - - input = input.reshape(-1) - reward = reward.reshape(-1) - mask = (seq>0).to(input) - mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1) - output = - input * reward * mask - output = torch.sum(output) / torch.sum(mask) - - return output - -class StructureLosses(nn.Module): - """ - This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018). 
- """ - def __init__(self, opt): - super(StructureLosses, self).__init__() - self.opt = opt - self.loss_type = opt.structure_loss_type - - def forward(self, input, seq, data_gts): - """ - Input is either logits or log softmax - """ - out = {} - - batch_size = input.size(0)# batch_size = sample_size * seq_per_img - seq_per_img = batch_size // len(data_gts) - - assert seq_per_img == self.opt.train_sample_n, seq_per_img - - mask = (seq>0).to(input) - mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1) - - scores = get_scores(data_gts, seq, self.opt) - scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img) - out['reward'] = scores #.mean() - if self.opt.entropy_reward_weight > 0: - entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data - entropy = (entropy * mask).sum(1) / mask.sum(1) - print('entropy', entropy.mean().item()) - scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img) - # rescale cost to [0,1] - costs = - scores - if self.loss_type == 'risk' or self.loss_type == 'softmax_margin': - costs = costs - costs.min(1, keepdim=True)[0] - costs = costs / costs.max(1, keepdim=True)[0] - # in principle - # Only risk need such rescale - # margin should be alright; Let's try. - - # Gather input: BxTxD -> BxT - input = input.gather(2, seq.unsqueeze(2)).squeeze(2) - - if self.loss_type == 'seqnll': - # input is logsoftmax - input = input * mask - input = input.sum(1) / mask.sum(1) - input = input.view(-1, seq_per_img) - - target = costs.min(1)[1] - output = F.cross_entropy(input, target) - elif self.loss_type == 'risk': - # input is logsoftmax - input = input * mask - input = input.sum(1) - input = input.view(-1, seq_per_img) - - output = (F.softmax(input.exp()) * costs).sum(1).mean() - - # test - # avg_scores = input - # probs = F.softmax(avg_scores.exp_()) - # loss = (probs * costs.type_as(probs)).sum() / input.size(0) - # print(output.item(), loss.item()) - - elif self.loss_type == 'max_margin': - # input is logits - input = input * mask - input = input.sum(1) / mask.sum(1) - input = input.view(-1, seq_per_img) - _, __ = costs.min(1, keepdim=True) - costs_star = _ - input_star = input.gather(1, __) - output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2 - output = output.mean() - - # sanity test - # avg_scores = input + costs - # scores_with_high_target = avg_scores.clone() - # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10) - - # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2] - # avg_scores = avg_scores.gather(1, target_and_offender_index) - # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long) - # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0) - # print(loss.item() * 2, output.item()) - - elif self.loss_type == 'multi_margin': - # input is logits - input = input * mask - input = input.sum(1) / mask.sum(1) - input = input.view(-1, seq_per_img) - _, __ = costs.min(1, keepdim=True) - costs_star = _ - input_star = input.gather(1, __) - output = F.relu(costs - costs_star - input_star + input) - output = output.mean() - - # sanity test - # avg_scores = input + costs - # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0) - # print(output, loss) - - elif self.loss_type == 'softmax_margin': - # input is logsoftmax - input = input * mask - input = input.sum(1) / mask.sum(1) - input = input.view(-1, seq_per_img) - - input = input + costs - target = costs.min(1)[1] - output = 
F.cross_entropy(input, target) - - elif self.loss_type == 'real_softmax_margin': - # input is logits - # This is what originally defined in Kevin's paper - # The result should be equivalent to softmax_margin - input = input * mask - input = input.sum(1) / mask.sum(1) - input = input.view(-1, seq_per_img) - - input = input + costs - target = costs.min(1)[1] - output = F.cross_entropy(input, target) - - elif self.loss_type == 'new_self_critical': - """ - A different self critical - Self critical uses greedy decoding score as baseline; - This setting uses the average score of the rest samples as baseline - (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) ) - """ - baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1) - scores = scores - baseline - # self cider used as reward to promote diversity (not working that much in this way) - if getattr(self.opt, 'self_cider_reward_weight', 0) > 0: - _scores = get_self_cider_scores(data_gts, seq, self.opt) - _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1) - _scores = _scores.expand_as(scores - 1) - scores += self.opt.self_cider_reward_weight * _scores - output = - input * mask * scores.view(-1, 1) - output = torch.sum(output) / torch.sum(mask) - - out['loss'] = output - return out - -class LanguageModelCriterion(nn.Module): - def __init__(self): - super(LanguageModelCriterion, self).__init__() - - def forward(self, input, target, mask): - if target.ndim == 3: - target = target.reshape(-1, target.shape[2]) - mask = mask.reshape(-1, mask.shape[2]) - # truncate to the same size - target = target[:, :input.size(1)] - mask = mask[:, :input.size(1)].to(input) - - output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask - # Average over each token - output = torch.sum(output) / torch.sum(mask) - - return output - -class LabelSmoothing(nn.Module): - "Implement label smoothing." 
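The forward pass below spreads `smoothing` probability mass uniformly over the non-target vocabulary entries and keeps `1 - smoothing` on the gold token, then takes the KL divergence against the model's log-probabilities under the sequence mask. A small standalone illustration with hypothetical sizes, not tied to the tensors used by the class:

import torch
import torch.nn.functional as F

vocab_size, smoothing = 5, 0.1
logits = torch.randn(3, vocab_size)              # 3 time steps
target = torch.tensor([1, 4, 2])

log_probs = F.log_softmax(logits, dim=-1)
true_dist = torch.full_like(log_probs, smoothing / (vocab_size - 1))
true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)  # confidence on the gold token

# elementwise KL matches KLDivLoss(reduce=False) followed by .sum(1) in the class below
loss = F.kl_div(log_probs, true_dist, reduction='none').sum(1).mean()
print(true_dist, loss)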
- def __init__(self, size=0, padding_idx=0, smoothing=0.0): - super(LabelSmoothing, self).__init__() - self.criterion = nn.KLDivLoss(size_average=False, reduce=False) - # self.padding_idx = padding_idx - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - # self.size = size - self.true_dist = None - - def forward(self, input, target, mask): - if target.ndim == 3: - target = target.reshape(-1, target.shape[2]) - mask = mask.reshape(-1, mask.shape[2]) - # truncate to the same size - target = target[:, :input.size(1)] - mask = mask[:, :input.size(1)] - - input = input.reshape(-1, input.size(-1)) - target = target.reshape(-1) - mask = mask.reshape(-1).to(input) - - # assert x.size(1) == self.size - self.size = input.size(1) - # true_dist = x.data.clone() - true_dist = input.data.clone() - # true_dist.fill_(self.smoothing / (self.size - 2)) - true_dist.fill_(self.smoothing / (self.size - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - # true_dist[:, self.padding_idx] = 0 - # mask = torch.nonzero(target.data == self.padding_idx) - # self.true_dist = true_dist - return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() \ No newline at end of file diff --git a/captioning/utils/__init__.py b/captioning/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/captioning/utils/clipscore.py b/captioning/utils/clipscore.py deleted file mode 100644 index 0345140d9f7b47e37b3a895915a135e1441c907b..0000000000000000000000000000000000000000 --- a/captioning/utils/clipscore.py +++ /dev/null @@ -1,396 +0,0 @@ -from transformers import CLIPModel, CLIPTokenizer -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - - -class CLIPScore(nn.Module): - def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False): - super(CLIPScore, self).__init__() - # from transformers import CLIPModel, CLIPTokenizer - self.clip_model = CLIPModel.from_pretrained( - 'openai/clip-vit-base-patch32') - self.tokenizer = CLIPTokenizer.from_pretrained( - 'openai/clip-vit-base-patch32') - - self.clip_model.eval() - - self.clipscore_w = clipscore_w - - self.image_transform = self._transform(image_size) - - self.mode = mode - assert mode in ['clip_s', 'refclip_s'] - - self.use_grammar = use_grammar - self.joint_out = joint_out - - if self.use_grammar and joint_out is False: - self.grammar_score_head = nn.Sequential( - nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False), - nn.ReLU(), - nn.Linear(self.clip_model.projection_dim, 2, bias=False) - ) - - def _transform(self, n_px): - return Compose([ - Resize(n_px, interpolation=Image.BICUBIC), - CenterCrop(n_px), - lambda image: image.convert("RGB"), - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)), - ]) - - def load_image(self, image_path): - image = Image.open(image_path) - return image - - # @torch.no_grad() - def image_extract(self, image): - if isinstance(image, str): - image = self.load_image(image) - if not isinstance(image, torch.Tensor): - image = self.image_transform(image) - - img_tensor = 
image.view(-1, 3, 224, 224) - device = next(self.clip_model.parameters()).device - img_tensor = img_tensor.to(device) - - clip_model = self.clip_model - - img_feat = clip_model.vision_model(img_tensor).pooler_output - img_feat = clip_model.visual_projection(img_feat) - img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) - - return img_feat - - # @torch.no_grad() - def text_extract(self, text, prompt="A photo depicts", proj_norm=True): - if isinstance(text, str): - text_batch = [" ".join([prompt, text])] - elif isinstance(text, list): - text_batch = [" ".join([prompt, txt]) for txt in text] - - if isinstance(text, tuple) and isinstance(text[0], torch.Tensor): - input_ids, attention_mask = text - else: - input_text = text_batch - - tokenized = self.tokenizer( - input_text, return_tensors='pt', padding=True, truncation=True) - - input_ids = tokenized.input_ids - attention_mask = tokenized.attention_mask - - clip_model = self.clip_model - device = next(self.clip_model.parameters()).device - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - - text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output - - if proj_norm: - text_feat = clip_model.text_projection(text_feat) - text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) - - return text_feat - - # @torch.no_grad() - def calc_clip_s(self, img_feat, text_feat): - return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) - - # @torch.no_grad() - def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): - - if clip_s is None: - clip_s = self.calc_clip_s(img_feat, text_feat) - - B, dim = img_feat.size() - - ref_text_feat = ref_text_feat.view(B, -1, dim) - - K = ref_text_feat.size(1) - - text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) - assert ref_text_feat.size() == text_feat.size( - ), (ref_text_feat.size(), text_feat.size()) - - ref_score = self.calc_clip_s(text_feat, ref_text_feat) - if ref_text_mask is not None: - if not isinstance(ref_text_mask, torch.Tensor): - ref_text_mask = torch.tensor( - ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) - ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) - - ref_score = ref_score.view(B, K).max(dim=1).values - - assert clip_s.size() == (B,) - assert clip_s.size() == ref_score.size() - - # harmonic mean - refclip_s = 2 / (1 / clip_s + 1 / ref_score) - return refclip_s - - @torch.no_grad() - def forward(self, - images=None, text=None, - img_feat=None, text_feat=None, - ref_text=None, ref_text_feat=None, ref_text_mask=None, - prompt="A photo depicts", - mode=None): - if img_feat is None: - img_feat = self.image_extract(images) - img_feat = img_feat.view(-1, 512) - - B = img_feat.size(0) - - if text_feat is None: - text_feat = self.text_extract(text, prompt=prompt) - text_feat = text_feat.view(-1, 512) - - if mode is None: - mode = self.mode - assert mode in ['clip_s', 'refclip_s'] - - if mode == 'clip_s': - clip_s = self.calc_clip_s(img_feat, text_feat) - return clip_s - elif mode == 'refclip_s': - if ref_text_feat is None: - ref_text_feat = self.text_extract(ref_text, prompt=prompt) - ref_text_feat = ref_text_feat.view(-1, 512) - - refclip_s = self.calc_refclip_s( - img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) - return refclip_s - - - def train_step(self, - images=None, text=None, - img_feat=None, text_feat=None, - neg_text=None, neg_text_feat=None, - # ref_text=None, ref_text_feat=None, ref_text_mask=None, - prompt="A photo 
depicts", - # return_loss=True, - **kwargs): - - if img_feat is None: - img_feat = self.image_extract(images) - img_feat = img_feat.view(-1, 512) - - B = img_feat.size(0) - - if text_feat is None: - text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) - - text_cont_feat = self.clip_model.text_projection(text_feat) - text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) - text_cont_feat = text_cont_feat.view(B, 512) - - # cosine similarity as logits - logit_scale = self.clip_model.logit_scale.exp() - logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale - # logits_per_image = logits_per_text.T - - clip_loss = clip_loss_fn(logits_per_text) - - - # negative sampling - pos_text_feat = text_feat.view(B, 512) - neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) - - grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) - - # 2B, 1 - grammar_text_logit = self.grammar_score_head(grammar_text_feat) - grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) - - grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) - - grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) - grammar_pos_pred = grammar_pred[:B] - grammar_neg_pred = grammar_pred[B:] - # grammar_acc = (grammar_pred == grammar_labels).float().mean() - - out = { - 'clip_loss': clip_loss, - 'grammar_loss': grammar_loss, - 'img_feat': img_feat, - 'text_feat': text_cont_feat, - 'neg_text_feat': neg_text_feat, - 'grammar_pos_pred': grammar_pos_pred, - 'grammar_neg_pred': grammar_neg_pred, - } - - return out - -# contrastive loss function, adapted from -# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html -def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor: - neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) - return -neg_ce.mean() - - -def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor: - caption_loss = contrastive_loss(similarity, dim=0) - image_loss = contrastive_loss(similarity, dim=1) - return (caption_loss + image_loss) / 2.0 - - - -# class CLIPScore(nn.Module): -# def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'): -# super(CLIPScore, self).__init__() -# # from transformers import CLIPModel, CLIPTokenizer -# self.clip_model = CLIPModel.from_pretrained( -# 'openai/clip-vit-base-patch32') -# self.tokenizer = CLIPTokenizer.from_pretrained( -# 'openai/clip-vit-base-patch32') - -# self.clip_model.eval() - -# self.clipscore_w = clipscore_w - -# self.image_transform = self._transform(image_size) - -# self.mode = mode -# assert mode in ['clip_s', 'refclip_s'] - -# def _transform(self, n_px): -# return Compose([ -# Resize(n_px, interpolation=Image.BICUBIC), -# CenterCrop(n_px), -# lambda image: image.convert("RGB"), -# ToTensor(), -# Normalize((0.48145466, 0.4578275, 0.40821073), -# (0.26862954, 0.26130258, 0.27577711)), -# ]) - -# def load_image(self, image_path): -# image = Image.open(image_path) -# return image - -# @torch.no_grad() -# def image_extract(self, image): -# if isinstance(image, str): -# image = self.load_image(image) -# if not isinstance(image, torch.Tensor): -# image = self.image_transform(image) - -# img_tensor = image.view(-1, 3, 224, 224) -# device = next(self.clip_model.parameters()).device -# img_tensor = img_tensor.to(device) - -# clip_model = self.clip_model - -# img_feat = clip_model.vision_model(img_tensor).pooler_output -# 
img_feat = clip_model.visual_projection(img_feat) -# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) - -# return img_feat - -# @torch.no_grad() -# def text_extract(self, text, prompt="A photo depicts"): -# if isinstance(text, str): -# text_batch = [" ".join([prompt, text])] -# else: -# text_batch = [" ".join([prompt, txt]) for txt in text] - -# input_text = text_batch - -# tokenized = self.tokenizer( -# input_text, return_tensors='pt', padding=True) - -# input_ids = tokenized.input_ids -# attention_mask = tokenized.attention_mask - -# clip_model = self.clip_model -# device = next(self.clip_model.parameters()).device -# input_ids = input_ids.to(device) -# attention_mask = attention_mask.to(device) - -# text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output -# text_feat = clip_model.text_projection(text_feat) -# text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) - -# return text_feat - -# @torch.no_grad() -# def calc_clip_s(self, img_feat, text_feat): -# return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) - -# @torch.no_grad() -# def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): - -# if clip_s is None: -# clip_s = self.calc_clip_s(img_feat, text_feat) - -# B, dim = img_feat.size() - -# ref_text_feat = ref_text_feat.view(B, -1, dim) - -# K = ref_text_feat.size(1) - -# text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) -# assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size()) - -# ref_score = self.calc_clip_s(text_feat, ref_text_feat) -# if ref_text_mask is not None: -# if not isinstance(ref_text_mask, torch.Tensor): -# ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) -# ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) - -# ref_score = ref_score.view(B, K).max(dim=1).values - -# assert clip_s.size() == (B,) -# assert clip_s.size() == ref_score.size() - -# # harmonic mean -# refclip_s = 2 / (1 / clip_s + 1 / ref_score) -# return refclip_s - - -# @torch.no_grad() -# def forward(self, -# images=None, text=None, -# img_feat=None, text_feat=None, -# ref_text=None, ref_text_feat=None, ref_text_mask=None, -# prompt="A photo depicts", -# mode=None): -# if img_feat is None: -# img_feat = self.image_extract(images) -# img_feat = img_feat.view(-1, 512) - -# if text_feat is None: -# text_feat = self.text_extract(text, prompt=prompt) -# text_feat = text_feat.view(-1, 512) - -# if mode is None: -# mode = self.mode -# assert mode in ['clip_s', 'refclip_s'] - -# if mode == 'clip_s': -# clip_s = self.calc_clip_s(img_feat, text_feat) -# return clip_s -# elif mode == 'refclip_s': -# if ref_text_feat is None: -# ref_text_feat = self.text_extract(ref_text, prompt=prompt) -# ref_text_feat = ref_text_feat.view(-1, 512) - -# refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) -# return refclip_s - diff --git a/captioning/utils/config.py b/captioning/utils/config.py deleted file mode 100644 index e42704dcba2fb2f751fec413551a5069e63f25c9..0000000000000000000000000000000000000000 --- a/captioning/utils/config.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
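For reference, the scoring math in the CLIPScore module above reduces to CLIP-S = w * max(0, cos(image, caption)) and RefCLIP-S = the harmonic mean of CLIP-S and the caption's best similarity to a reference caption. A minimal numeric sketch with pre-normalized toy features, so no CLIP weights are needed:

import torch
import torch.nn.functional as F

def clip_s(img_feat, txt_feat, w=2.5):
    # features are assumed L2-normalized, so the dot product is a cosine similarity;
    # relu clamps negative similarities to zero, mirroring calc_clip_s above
    return w * torch.relu((img_feat * txt_feat).sum(-1))

def refclip_s(img_feat, txt_feat, ref_feats, w=2.5):
    s = clip_s(img_feat, txt_feat, w)                        # (B,)
    ref_sim = clip_s(txt_feat.unsqueeze(1), ref_feats, w)    # (B, K) caption vs. each reference
    best_ref = ref_sim.max(dim=1).values
    return 2.0 / (1.0 / s + 1.0 / best_ref)                  # harmonic mean

B, K, D = 2, 3, 4
img = F.normalize(torch.randn(B, D), dim=-1)
txt = F.normalize(torch.randn(B, D), dim=-1)
refs = F.normalize(torch.randn(B, K, D), dim=-1)
print(clip_s(img, txt), refclip_s(img, txt, refs))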
-# Copy from fvcore - -import logging -import os -from typing import Any -import yaml -from yacs.config import CfgNode as _CfgNode - -import io as PathManager - -BASE_KEY = "_BASE_" - - -class CfgNode(_CfgNode): - """ - Our own extended version of :class:`yacs.config.CfgNode`. - It contains the following extra features: - - 1. The :meth:`merge_from_file` method supports the "_BASE_" key, - which allows the new CfgNode to inherit all the attributes from the - base configuration file. - 2. Keys that start with "COMPUTED_" are treated as insertion-only - "computed" attributes. They can be inserted regardless of whether - the CfgNode is frozen or not. - 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate - expressions in config. See examples in - https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types - Note that this may lead to arbitrary code execution: you must not - load a config file from untrusted sources before manually inspecting - the content of the file. - """ - - @staticmethod - def load_yaml_with_base(filename, allow_unsafe = False): - """ - Just like `yaml.load(open(filename))`, but inherit attributes from its - `_BASE_`. - - Args: - filename (str): the file name of the current config. Will be used to - find the base config file. - allow_unsafe (bool): whether to allow loading the config file with - `yaml.unsafe_load`. - - Returns: - (dict): the loaded yaml - """ - with PathManager.open(filename, "r") as f: - try: - cfg = yaml.safe_load(f) - except yaml.constructor.ConstructorError: - if not allow_unsafe: - raise - logger = logging.getLogger(__name__) - logger.warning( - "Loading config {} with yaml.unsafe_load. Your machine may " - "be at risk if the file contains malicious content.".format( - filename - ) - ) - f.close() - with open(filename, "r") as f: - cfg = yaml.unsafe_load(f) - - def merge_a_into_b(a, b): - # merge dict a into dict b. values in a will overwrite b. - for k, v in a.items(): - if isinstance(v, dict) and k in b: - assert isinstance( - b[k], dict - ), "Cannot inherit key '{}' from base!".format(k) - merge_a_into_b(v, b[k]) - else: - b[k] = v - - if BASE_KEY in cfg: - base_cfg_file = cfg[BASE_KEY] - if base_cfg_file.startswith("~"): - base_cfg_file = os.path.expanduser(base_cfg_file) - if not any( - map(base_cfg_file.startswith, ["/", "https://", "http://"]) - ): - # the path to base cfg is relative to the config file itself. - base_cfg_file = os.path.join( - os.path.dirname(filename), base_cfg_file - ) - base_cfg = CfgNode.load_yaml_with_base( - base_cfg_file, allow_unsafe=allow_unsafe - ) - del cfg[BASE_KEY] - - merge_a_into_b(cfg, base_cfg) - return base_cfg - return cfg - - def merge_from_file(self, cfg_filename, allow_unsafe = False): - """ - Merge configs from a given yaml file. - - Args: - cfg_filename: the file name of the yaml config. - allow_unsafe: whether to allow loading the config file with - `yaml.unsafe_load`. - """ - loaded_cfg = CfgNode.load_yaml_with_base( - cfg_filename, allow_unsafe=allow_unsafe - ) - loaded_cfg = type(self)(loaded_cfg) - self.merge_from_other_cfg(loaded_cfg) - - # Forward the following calls to base, but with a check on the BASE_KEY. - def merge_from_other_cfg(self, cfg_other): - """ - Args: - cfg_other (CfgNode): configs to merge from. 
- """ - assert ( - BASE_KEY not in cfg_other - ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) - return super().merge_from_other_cfg(cfg_other) - - def merge_from_list(self, cfg_list): - """ - Args: - cfg_list (list): list of configs to merge from. - """ - keys = set(cfg_list[0::2]) - assert ( - BASE_KEY not in keys - ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) - return super().merge_from_list(cfg_list) - - def __setattr__(self, name, val): - if name.startswith("COMPUTED_"): - if name in self: - old_val = self[name] - if old_val == val: - return - raise KeyError( - "Computed attributed '{}' already exists " - "with a different value! old={}, new={}.".format( - name, old_val, val - ) - ) - self[name] = val - else: - super().__setattr__(name, val) - - -if __name__ == '__main__': - cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') - print(cfg) \ No newline at end of file diff --git a/captioning/utils/dist_utils.py b/captioning/utils/dist_utils.py deleted file mode 100644 index 53a7c462570edb8f381c65fabf60c729f1607f41..0000000000000000000000000000000000000000 --- a/captioning/utils/dist_utils.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -This file contains primitives for multi-gpu communication. -This is useful when doing distributed training. -""" - -import functools -import logging -import numpy as np -import pickle -import torch -import torch.distributed as dist - -import torch - -_LOCAL_PROCESS_GROUP = None -""" -A torch process group which only includes processes that on the same machine as the current process. -This variable is set when processes are spawned by `launch()` in "engine/launch.py". -""" - - -def get_world_size() -> int: - if not dist.is_available(): - return 1 - if not dist.is_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank() -> int: - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - return dist.get_rank() - - -def get_local_rank() -> int: - """ - Returns: - The rank of the current process within the local (per-machine) process group. - """ - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - assert _LOCAL_PROCESS_GROUP is not None - return dist.get_rank(group=_LOCAL_PROCESS_GROUP) - - -def get_local_size() -> int: - """ - Returns: - The size of the per-machine process group, - i.e. the number of processes per machine. - """ - if not dist.is_available(): - return 1 - if not dist.is_initialized(): - return 1 - return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) - - -def is_main_process() -> bool: - return get_rank() == 0 - - -def synchronize(): - """ - Helper function to synchronize (barrier) among all processes when - using distributed training - """ - if not dist.is_available(): - return - if not dist.is_initialized(): - return - world_size = dist.get_world_size() - if world_size == 1: - return - dist.barrier() - - -@functools.lru_cache() -def _get_global_gloo_group(): - """ - Return a process group based on gloo backend, containing all the ranks - The result is cached. 
- """ - if dist.get_backend() == "nccl": - return dist.new_group(backend="gloo") - else: - return dist.group.WORLD - - -def _serialize_to_tensor(data, group): - backend = dist.get_backend(group) - assert backend in ["gloo", "nccl"] - device = torch.device("cpu" if backend == "gloo" else "cuda") - - buffer = pickle.dumps(data) - if len(buffer) > 1024 ** 3: - logger = logging.getLogger(__name__) - logger.warning( - "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( - get_rank(), len(buffer) / (1024 ** 3), device - ) - ) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to(device=device) - return tensor - - -def _pad_to_largest_tensor(tensor, group): - """ - Returns: - list[int]: size of the tensor, on each rank - Tensor: padded tensor that has the max size - """ - world_size = dist.get_world_size(group=group) - assert ( - world_size >= 1 - ), "comm.gather/all_gather must be called from ranks within the given group!" - local_size = torch.tensor( - [tensor.numel()], dtype=torch.int64, device=tensor.device) - size_list = [ - torch.zeros([1], dtype=torch.int64, device=tensor.device) - for _ in range(world_size) - ] - dist.all_gather(size_list, local_size, group=group) - size_list = [int(size.item()) for size in size_list] - - max_size = max(size_list) - - # we pad the tensor because torch all_gather does not support - # gathering tensors of different shapes - if local_size != max_size: - padding = torch.zeros( - (max_size - local_size,), dtype=torch.uint8, device=tensor.device - ) - tensor = torch.cat((tensor, padding), dim=0) - return size_list, tensor - - -def all_gather(data, group=None): - """ - Run all_gather on arbitrary picklable data (not necessarily tensors). - Args: - data: any picklable object - group: a torch process group. By default, will use a group which - contains all ranks on gloo backend. - Returns: - list[data]: list of data gathered from each rank - """ - if get_world_size() == 1: - return [data] - if group is None: - group = _get_global_gloo_group() - if dist.get_world_size(group) == 1: - return [data] - - tensor = _serialize_to_tensor(data, group) - - size_list, tensor = _pad_to_largest_tensor(tensor, group) - max_size = max(size_list) - - # receiving Tensor from all ranks - tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) - for _ in size_list - ] - dist.all_gather(tensor_list, tensor, group=group) - - data_list = [] - for size, tensor in zip(size_list, tensor_list): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - - return data_list - - -def gather(data, dst=0, group=None): - """ - Run gather on arbitrary picklable data (not necessarily tensors). - Args: - data: any picklable object - dst (int): destination rank - group: a torch process group. By default, will use a group which - contains all ranks on gloo backend. - Returns: - list[data]: on dst, a list of data gathered from each rank. Otherwise, - an empty list. 
- """ - if get_world_size() == 1: - return [data] - if group is None: - group = _get_global_gloo_group() - if dist.get_world_size(group=group) == 1: - return [data] - rank = dist.get_rank(group=group) - - tensor = _serialize_to_tensor(data, group) - size_list, tensor = _pad_to_largest_tensor(tensor, group) - - # receiving Tensor from all ranks - if rank == dst: - max_size = max(size_list) - tensor_list = [ - torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) - for _ in size_list - ] - dist.gather(tensor, tensor_list, dst=dst, group=group) - - data_list = [] - for size, tensor in zip(size_list, tensor_list): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - return data_list - else: - dist.gather(tensor, [], dst=dst, group=group) - return [] - - -def shared_random_seed(): - """ - Returns: - int: a random number that is the same across all workers. - If workers need a shared RNG, they can use this shared seed to - create one. - All workers must call this function, otherwise it will deadlock. - """ - ints = np.random.randint(2 ** 31) - all_ints = all_gather(ints) - return all_ints[0] - - -# def reduce_dict(input_dict, average=True): -# """ -# Reduce the values in the dictionary from all processes so that process with rank -# 0 has the reduced results. -# Args: -# input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. -# average (bool): whether to do average or sum -# Returns: -# a dict with the same keys as input_dict, after reduction. -# """ -# world_size = get_world_size() -# if world_size < 2: -# return input_dict -# with torch.no_grad(): -# names = [] -# values = [] -# # sort the keys so that they are consistent across processes -# for k in sorted(input_dict.keys()): -# names.append(k) -# values.append(input_dict[k]) -# values = torch.stack(values, dim=0) -# dist.reduce(values, dst=0) -# if dist.get_rank() == 0 and average: -# # only main process gets accumulated, so only divide by -# # world_size in this case -# values /= world_size -# reduced_dict = {k: v for k, v in zip(names, values)} -# return reduced_dict - - -def reduce_dict(input_dict, average=True): - """ - Reduce the values in the dictionary from all processes so that process with rank - 0 has the reduced results. - Args: - input_dict (dict): inputs to be reduced. (values not necessarily tensors). - average (bool): whether to do average or sum - Returns: - a dict with the same keys as input_dict, after reduction. 
- """ - - world_size = get_world_size() - if world_size < 2: - return input_dict - - with torch.no_grad(): - - # Convert to CUDA Tensor for dist.reduce() - input_dict_cuda_vals = {} - for k, v in input_dict.items(): - if type(v) == torch.Tensor: - input_dict_cuda_vals[k] = v.to('cuda') - else: - input_dict_cuda_vals[k] = torch.tensor(v, device='cuda') - - names = [] - values = [] - for k, v in sorted(input_dict_cuda_vals.items()): - names.append(k) - values.append(v) - values = torch.stack(values, dim=0) - dist.reduce(values, dst=0) # reduce to gpu 0 - - if dist.get_rank() == 0 and average: - # only main process gets accumulated, so only divide by - # world_size in this case - values /= world_size - reduced_dict = {k: v for k, v in zip(names, values)} - return reduced_dict diff --git a/captioning/utils/div_utils.py b/captioning/utils/div_utils.py deleted file mode 100644 index a757eb7b2184767f8ea2351b30cce6601a45be78..0000000000000000000000000000000000000000 --- a/captioning/utils/div_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from random import uniform -import numpy as np -from collections import OrderedDict, defaultdict -from itertools import tee -import time - -# ----------------------------------------------- -def find_ngrams(input_list, n): - return zip(*[input_list[i:] for i in range(n)]) - -def compute_div_n(caps,n=1): - aggr_div = [] - for k in caps: - all_ngrams = set() - lenT = 0. - for c in caps[k]: - tkns = c.split() - lenT += len(tkns) - ng = find_ngrams(tkns, n) - all_ngrams.update(ng) - aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) - return np.array(aggr_div).mean(), np.array(aggr_div) - -def compute_global_div_n(caps,n=1): - aggr_div = [] - all_ngrams = set() - lenT = 0. - for k in caps: - for c in caps[k]: - tkns = c.split() - lenT += len(tkns) - ng = find_ngrams(tkns, n) - all_ngrams.update(ng) - if n == 1: - aggr_div.append(float(len(all_ngrams))) - else: - aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT))) - return aggr_div[0], np.repeat(np.array(aggr_div),len(caps)) \ No newline at end of file diff --git a/captioning/utils/eval_multi.py b/captioning/utils/eval_multi.py deleted file mode 100644 index 83907410b806a50002aa32db289ca86cff72f45d..0000000000000000000000000000000000000000 --- a/captioning/utils/eval_multi.py +++ /dev/null @@ -1,218 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn - -import numpy as np -import json -from json import encoder -import random -import string -import time -import os -import sys -from . 
import misc as utils -from eval_utils import getCOCO - -from .div_utils import compute_div_n, compute_global_div_n - -import sys -try: - sys.path.append("coco-caption") - annFile = 'coco-caption/annotations/captions_val2014.json' - from pycocotools.coco import COCO - from pycocoevalcap.eval import COCOEvalCap - from pycocoevalcap.eval_spice import COCOEvalCapSpice - from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer - from pycocoevalcap.bleu.bleu import Bleu - sys.path.append("cider") - from pyciderevalcap.cider.cider import Cider -except: - print('Warning: requirements for eval_multi not satisfied') - - -def eval_allspice(dataset, preds_n, model_id, split): - coco = getCOCO(dataset) - valids = coco.getImgIds() - - capsById = {} - for d in preds_n: - capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] - - # filter results to only those in MSCOCO validation set (will be about a third) - preds_filt_n = [p for p in preds_n if p['image_id'] in valids] - print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n))) - cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json') - json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API... - - # Eval AllSPICE - cocoRes_n = coco.loadRes(cache_path_n) - cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n) - cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds() - cocoEvalAllSPICE.evaluate() - - out = {} - for metric, score in cocoEvalAllSPICE.eval.items(): - out['All'+metric] = score - - imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval - # collect SPICE_sub_score - for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys(): - if k != 'All': - out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()]) - out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean() - for p in preds_filt_n: - image_id, caption = p['image_id'], p['caption'] - imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id] - return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE} - -def eval_oracle(dataset, preds_n, model_id, split): - cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') - - coco = getCOCO(dataset) - valids = coco.getImgIds() - - capsById = {} - for d in preds_n: - capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] - - sample_n = capsById[list(capsById.keys())[0]] - for i in range(len(capsById[list(capsById.keys())[0]])): - preds = [_[i] for _ in capsById.values()] - - json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 
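The oracle evaluation above scores each of the n sampled captions for an image separately and then reports the best value ('oracle_*') and the average value ('avg_*') of every metric per image. Stripped of the COCO API plumbing, the aggregation amounts to the following sketch (hypothetical scores, plain Python):

# hypothetical per-image scores for n=3 sampled captions (metric -> list of values)
scores_by_image = {
    184321: {"CIDEr": [0.9, 1.4, 1.1], "SPICE": [0.18, 0.22, 0.20]},
    325114: {"CIDEr": [0.7, 0.6, 1.0], "SPICE": [0.15, 0.11, 0.19]},
}

per_image, overall = {}, {}
for img_id, metrics in scores_by_image.items():
    per_image[img_id] = {}
    for name, vals in metrics.items():
        per_image[img_id]["oracle_" + name] = max(vals)            # best sample
        per_image[img_id]["avg_" + name] = sum(vals) / len(vals)   # mean over samples

for key in next(iter(per_image.values())):
    overall[key] = sum(v[key] for v in per_image.values()) / len(per_image)

print(per_image)
print(overall)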
- - cocoRes = coco.loadRes(cache_path) - cocoEval = COCOEvalCap(coco, cocoRes) - cocoEval.params['image_id'] = cocoRes.getImgIds() - cocoEval.evaluate() - - imgToEval = cocoEval.imgToEval - for img_id in capsById.keys(): - tmp = imgToEval[img_id] - for k in tmp['SPICE'].keys(): - if k != 'All': - tmp['SPICE_'+k] = tmp['SPICE'][k]['f'] - if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan - tmp['SPICE_'+k] = -100 - tmp['SPICE'] = tmp['SPICE']['All']['f'] - if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100 - capsById[img_id][i]['scores'] = imgToEval[img_id] - - out = {'overall': {}, 'ImgToEval': {}} - for img_id in capsById.keys(): - out['ImgToEval'][img_id] = {} - for metric in capsById[img_id][0]['scores'].keys(): - if metric == 'image_id': continue - out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]]) - out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id]) - out['ImgToEval'][img_id]['captions'] = capsById[img_id] - for metric in list(out['ImgToEval'].values())[0].keys(): - if metric == 'captions': - continue - tmp = np.array([_[metric] for _ in out['ImgToEval'].values()]) - tmp = tmp[tmp!=-100] - out['overall'][metric] = tmp.mean() - - return out - -def eval_div_stats(dataset, preds_n, model_id, split): - tokenizer = PTBTokenizer() - - capsById = {} - for i, d in enumerate(preds_n): - d['id'] = i - capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] - - n_caps_perimg = len(capsById[list(capsById.keys())[0]]) - print(n_caps_perimg) - _capsById = capsById # save the untokenized version - capsById = tokenizer.tokenize(capsById) - - div_1, adiv_1 = compute_div_n(capsById,1) - div_2, adiv_2 = compute_div_n(capsById,2) - - globdiv_1, _= compute_global_div_n(capsById,1) - - print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1)) - - # compute mbleu - scorer = Bleu(4) - all_scrs = [] - scrperimg = np.zeros((n_caps_perimg, len(capsById))) - - for i in range(n_caps_perimg): - tempRefsById = {} - candsById = {} - for k in capsById: - tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:] - candsById[k] = [capsById[k][i]] - - score, scores = scorer.compute_score(tempRefsById, candsById) - all_scrs.append(score) - scrperimg[i,:] = scores[1] - - all_scrs = np.array(all_scrs) - - out = {} - out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1} - for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()): - out['overall'].update({'mBLeu_%d'%(k+1): score}) - imgToEval = {} - for i,imgid in enumerate(capsById.keys()): - imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()} - imgToEval[imgid]['individuals'] = [] - for j, d in enumerate(_capsById[imgid]): - imgToEval[imgid]['individuals'].append(preds_n[d['id']]) - imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i] - out['ImgToEval'] = imgToEval - - print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4') - print(all_scrs.mean(axis=0)) - - return out - -def eval_self_cider(dataset, preds_n, model_id, split): - cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json') - - coco = getCOCO(dataset) - valids = coco.getImgIds() - - # Get Cider_scorer - Cider_scorer = Cider(df='corpus') - - tokenizer = PTBTokenizer() - gts = {} - for imgId in valids: - gts[imgId] = coco.imgToAnns[imgId] - gts = tokenizer.tokenize(gts) - - for imgId in valids: - Cider_scorer.cider_scorer += (None, gts[imgId]) - 
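The Div-1/Div-2 statistics printed by eval_div_stats above are simply the ratio of distinct n-grams to the total number of generated tokens across an image's sampled captions (compute_div_n from div_utils). A dependency-free sketch with made-up captions:

def find_ngrams(tokens, n):
    return zip(*[tokens[i:] for i in range(n)])

def div_n(captions, n=1):
    # distinct-n: unique n-grams across the captions divided by total token count
    ngrams, total_tokens = set(), 0
    for caption in captions:
        tokens = caption.split()
        total_tokens += len(tokens)
        ngrams.update(find_ngrams(tokens, n))
    return len(ngrams) / (1e-6 + total_tokens)

caps = ["a dog runs on the beach", "a dog running on a beach", "the dog is on the sand"]
print("Div-1:", div_n(caps, 1), "Div-2:", div_n(caps, 2))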
Cider_scorer.cider_scorer.compute_doc_freq() - Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs))) - - # Prepare captions - capsById = {} - for d in preds_n: - capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] - - capsById = tokenizer.tokenize(capsById) - imgIds = list(capsById.keys()) - scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds]) - - def get_div(eigvals): - eigvals = np.clip(eigvals, 0, None) - return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals)) - sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores] - score = np.mean(np.array(sc_scores)) - - imgToEval = {} - for i, image_id in enumerate(imgIds): - imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()} - return {'overall': {'self_cider': score}, 'imgToEval': imgToEval} - - - return score diff --git a/captioning/utils/eval_utils.py b/captioning/utils/eval_utils.py deleted file mode 100644 index c4bc7f4471e6d3e1fcc2f80af6f47bfec5d920a1..0000000000000000000000000000000000000000 --- a/captioning/utils/eval_utils.py +++ /dev/null @@ -1,281 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import numpy as np -import json -from json import encoder -import random -import string -import time -import os -import sys -from . import misc as utils - -# load coco-caption if available -try: - sys.path.append("coco-caption") - from pycocotools.coco import COCO - from pycocoevalcap.eval import COCOEvalCap -except: - print('Warning: coco-caption not available') - -bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] -bad_endings += ['the'] - - -def count_bad(sen): - sen = sen.split(' ') - if sen[-1] in bad_endings: - return 1 - else: - return 0 - - -def getCOCO(dataset): - if 'coco' in dataset: - annFile = 'coco-caption/annotations/captions_val2014.json' - elif 'flickr30k' in dataset or 'f30k' in dataset: - annFile = 'data/f30k_captions4eval.json' - return COCO(annFile) - - -def language_eval(dataset, preds, preds_n, eval_kwargs, split): - model_id = eval_kwargs['id'] - eval_oracle = eval_kwargs.get('eval_oracle', 0) - - # create output dictionary - out = {} - - if len(preds_n) > 0: - # vocab size and novel sentences - if 'coco' in dataset: - dataset_file = 'data/dataset_coco.json' - elif 'flickr30k' in dataset or 'f30k' in dataset: - dataset_file = 'data/dataset_flickr30k.json' - training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']]) - generated_sentences = set([_['caption'] for _ in preds_n]) - novels = generated_sentences - training_sentences - out['novel_sentences'] = float(len(novels)) / len(preds_n) - tmp = [_.split() for _ in generated_sentences] - words = [] - for _ in tmp: - words += _ - out['vocab_size'] = len(set(words)) - - # encoder.FLOAT_REPR = lambda o: format(o, '.3f') - - cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json') - - coco = getCOCO(dataset) - valids = coco.getImgIds() - - # filter results to only those in MSCOCO validation set - preds_filt = [p for p in preds if p['image_id'] in valids] - mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt) - mean_entropy = sum([_['entropy'] for _ in preds_filt]) / 
len(preds_filt) - print('using %d/%d predictions' % (len(preds_filt), len(preds))) - json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... - - cocoRes = coco.loadRes(cache_path) - cocoEval = COCOEvalCap(coco, cocoRes) - cocoEval.params['image_id'] = cocoRes.getImgIds() - cocoEval.evaluate() - - for metric, score in cocoEval.eval.items(): - out[metric] = score - # Add mean perplexity - out['perplexity'] = mean_perplexity - out['entropy'] = mean_entropy - - imgToEval = cocoEval.imgToEval - for k in list(imgToEval.values())[0]['SPICE'].keys(): - if k != 'All': - out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()]) - out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean() - for p in preds_filt: - image_id, caption = p['image_id'], p['caption'] - imgToEval[image_id]['caption'] = caption - - if len(preds_n) > 0: - from . import eval_multi - cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json') - allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split) - out.update(allspice['overall']) - div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split) - out.update(div_stats['overall']) - if eval_oracle: - oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split) - out.update(oracle['overall']) - else: - oracle = None - self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split) - out.update(self_cider['overall']) - with open(cache_path_n, 'w') as outfile: - json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile) - - out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt)) - outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json') - with open(outfile_path, 'w') as outfile: - json.dump({'overall': out, 'imgToEval': imgToEval}, outfile) - - return out - -def eval_split(model, crit, loader, eval_kwargs={}): - verbose = eval_kwargs.get('verbose', True) - verbose_beam = eval_kwargs.get('verbose_beam', 0) - verbose_loss = eval_kwargs.get('verbose_loss', 1) - num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) - split = eval_kwargs.get('split', 'val') - lang_eval = eval_kwargs.get('language_eval', 0) - dataset = eval_kwargs.get('dataset', 'coco') - beam_size = eval_kwargs.get('beam_size', 1) - sample_n = eval_kwargs.get('sample_n', 1) - remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) - os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration - device = eval_kwargs.get('device', 'cuda') - - # Make sure in the evaluation mode - model.eval() - - loader.reset_iterator(split) - - n = 0 - loss = 0 - loss_sum = 0 - loss_evals = 1e-8 - predictions = [] - n_predictions = [] # when sample_n > 1 - while True: - data = loader.get_batch(split) - n = n + len(data['infos']) - - tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] - tmp = [_.to(device) if _ is not None else _ for _ in tmp] - fc_feats, att_feats, labels, masks, att_masks = tmp - if labels is not None and verbose_loss: - # forward the model to get loss - with torch.no_grad(): - loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item() - loss_sum = loss_sum + loss - loss_evals = loss_evals + 1 - - # forward the model to also get generated samples for 
each image - with torch.no_grad(): - tmp_eval_kwargs = eval_kwargs.copy() - tmp_eval_kwargs.update({'sample_n': 1}) - seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - seq = seq.data - entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) - perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) - - # Print beam search - if beam_size > 1 and verbose_beam: - for i in range(fc_feats.shape[0]): - print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) - print('--' * 10) - sents = utils.decode_sequence(model.vocab, seq) - - for k, sent in enumerate(sents): - entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()} - if eval_kwargs.get('dump_path', 0) == 1: - entry['file_name'] = data['infos'][k]['file_path'] - predictions.append(entry) - if eval_kwargs.get('dump_images', 0) == 1: - # dump the raw image to vis/ folder - cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross - print(cmd) - os.system(cmd) - - if verbose: - print('image %s: %s' %(entry['image_id'], entry['caption'])) - - if sample_n > 1: - eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs) - - # ix0 = data['bounds']['it_pos_now'] - ix1 = data['bounds']['it_max'] - if num_images != -1: - ix1 = min(ix1, num_images) - else: - num_images = ix1 - for i in range(n - ix1): - predictions.pop() - - if verbose: - print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss)) - - if num_images >= 0 and n >= num_images: - break - - lang_stats = None - if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]: - n_predictions = sorted(n_predictions, key=lambda x: x['perplexity']) - if not os.path.isdir('eval_results'): - os.mkdir('eval_results') - torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth')) - if lang_eval == 1: - lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split) - - # Switch back to training mode - model.train() - return loss_sum/loss_evals, predictions, lang_stats - - -# Only run when sample_n > 0 -def eval_split_n(model, n_predictions, input_data, eval_kwargs={}): - verbose = eval_kwargs.get('verbose', True) - beam_size = eval_kwargs.get('beam_size', 1) - sample_n = eval_kwargs.get('sample_n', 1) - sample_n_method = eval_kwargs.get('sample_n_method', 'sample') - - fc_feats, att_feats, att_masks, data = input_data - - tmp_eval_kwargs = eval_kwargs.copy() - if sample_n_method == 'bs': - # case 1 sample_n == beam size - tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax - with torch.no_grad(): - model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - for k in range(fc_feats.shape[0]): - _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)])) - for sent in _sents: - entry = {'image_id': data['infos'][k]['id'], 'caption': sent} - n_predictions.append(entry) - # case 2 sample / gumbel / topk sampling/ nucleus sampling - elif sample_n_method == 'sample' or \ - sample_n_method == 'gumbel' or \ - sample_n_method.startswith('top'): - tmp_eval_kwargs.update({'sample_n': sample_n, 
'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample - with torch.no_grad(): - _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - _sents = utils.decode_sequence(model.vocab, _seq) - _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1) - for k, sent in enumerate(_sents): - entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()} - n_predictions.append(entry) - elif sample_n_method == 'dbs': - # Use diverse beam search - tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax - with torch.no_grad(): - model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - for k in range(loader.batch_size): - _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)])) - for sent in _sents: - entry = {'image_id': data['infos'][k]['id'], 'caption': sent} - n_predictions.append(entry) - else: - tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax - with torch.no_grad(): - _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - _sents = utils.decode_sequence(model.vocab, _seq) - for k, sent in enumerate(_sents): - entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent} - n_predictions.append(entry) - if verbose: - for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']): - print('image %s: %s' %(entry['image_id'], entry['caption'])) \ No newline at end of file diff --git a/captioning/utils/misc.py b/captioning/utils/misc.py deleted file mode 100644 index 3edcc1b51c99e66c568fa5d3d93f131911096489..0000000000000000000000000000000000000000 --- a/captioning/utils/misc.py +++ /dev/null @@ -1,251 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import torch -import torch.nn as nn -import numpy as np -import torch.optim as optim -import os - -import torch.nn.functional as F - -import six -from six.moves import cPickle - -bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that'] -bad_endings += ['the'] - - -def pickle_load(f): - """ Load a pickle. - Parameters - ---------- - f: file-like object - """ - if six.PY3: - return cPickle.load(f, encoding='latin-1') - else: - return cPickle.load(f) - - -def pickle_dump(obj, f): - """ Dump a pickle. - Parameters - ---------- - obj: pickled object - f: file-like object - """ - if six.PY3: - return cPickle.dump(obj, f, protocol=2) - else: - return cPickle.dump(obj, f) - - -# modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py -def serialize_to_tensor(data): - device = torch.device("cpu") - - buffer = cPickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to(device=device) - return tensor - - -def deserialize(tensor): - buffer = tensor.cpu().numpy().tobytes() - return cPickle.loads(buffer) - - -# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 
-def decode_sequence(ix_to_word, seq): - # N, D = seq.size() - N, D = seq.shape - out = [] - for i in range(N): - txt = '' - for j in range(D): - ix = seq[i,j] - if ix > 0 : - if j >= 1: - txt = txt + ' ' - txt = txt + ix_to_word[str(ix.item())] - else: - break - if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): - flag = 0 - words = txt.split(' ') - for j in range(len(words)): - if words[-j-1] not in bad_endings: - flag = -j - break - txt = ' '.join(words[0:len(words)+flag]) - out.append(txt.replace('@@ ', '')) - return out - - -def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''): - if len(append) > 0: - append = '-' + append - # if checkpoint_path doesn't exist - if not os.path.isdir(opt.checkpoint_path): - os.makedirs(opt.checkpoint_path) - checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append)) - torch.save(model.state_dict(), checkpoint_path) - print("model saved to {}".format(checkpoint_path)) - optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append)) - torch.save(optimizer.state_dict(), optimizer_path) - with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f: - pickle_dump(infos, f) - if histories: - with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f: - pickle_dump(histories, f) - - -def set_lr(optimizer, lr): - for group in optimizer.param_groups: - group['lr'] = lr - -def get_lr(optimizer): - for group in optimizer.param_groups: - return group['lr'] - - -def build_optimizer(params, opt): - if opt.optim == 'rmsprop': - return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay) - elif opt.optim == 'adagrad': - return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) - elif opt.optim == 'sgd': - return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) - elif opt.optim == 'sgdm': - return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) - elif opt.optim == 'sgdmom': - return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) - elif opt.optim == 'adam': - return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) - elif opt.optim == 'adamw': - return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) - else: - raise Exception("bad option opt.optim: {}".format(opt.optim)) - - -def penalty_builder(penalty_config): - if penalty_config == '': - return lambda x,y: y - pen_type, alpha = penalty_config.split('_') - alpha = float(alpha) - if pen_type == 'wu': - return lambda x,y: length_wu(x,y,alpha) - if pen_type == 'avg': - return lambda x,y: length_average(x,y,alpha) - -def length_wu(length, logprobs, alpha=0.): - """ - NMT length re-ranking score from - "Google's Neural Machine Translation System" :cite:`wu2016google`. - """ - - modifier = (((5 + length) ** alpha) / - ((5 + 1) ** alpha)) - return (logprobs / modifier) - -def length_average(length, logprobs, alpha=0.): - """ - Returns the average probability of tokens in a sequence. - """ - return logprobs / length - - -class NoamOpt(object): - "Optim wrapper that implements rate." 
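The rate() method below implements the warmup-then-decay schedule from "Attention Is All You Need": lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), i.e. a linear warmup for `warmup` steps followed by inverse-square-root decay. A quick standalone check of its shape (illustrative numbers only):

def noam_lr(step, model_size=512, factor=1.0, warmup=2000):
    # linear warmup for `warmup` steps, then inverse-sqrt decay
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 500, 2000, 10000, 100000):
    print(step, round(noam_lr(step), 6))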
- def __init__(self, model_size, factor, warmup, optimizer): - self.optimizer = optimizer - self._step = 0 - self.warmup = warmup - self.factor = factor - self.model_size = model_size - self._rate = 0 - - def step(self): - "Update parameters and rate" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p['lr'] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step = None): - "Implement `lrate` above" - if step is None: - step = self._step - return self.factor * \ - (self.model_size ** (-0.5) * - min(step ** (-0.5), step * self.warmup ** (-1.5))) - - def __getattr__(self, name): - return getattr(self.optimizer, name) - - def state_dict(self): - state_dict = self.optimizer.state_dict() - state_dict['_step'] = self._step - return state_dict - - def load_state_dict(self, state_dict): - if '_step' in state_dict: - self._step = state_dict['_step'] - del state_dict['_step'] - self.optimizer.load_state_dict(state_dict) - -class ReduceLROnPlateau(object): - "Optim wrapper that implements rate." - def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): - self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps) - self.optimizer = optimizer - self.current_lr = get_lr(optimizer) - - def step(self): - "Update parameters and rate" - self.optimizer.step() - - def scheduler_step(self, val): - self.scheduler.step(val) - self.current_lr = get_lr(self.optimizer) - - def state_dict(self): - return {'current_lr':self.current_lr, - 'scheduler_state_dict': self.scheduler.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict()} - - def load_state_dict(self, state_dict): - if 'current_lr' not in state_dict: - # it's a plain optimizer state dict - self.optimizer.load_state_dict(state_dict) - set_lr(self.optimizer, self.current_lr) # use the lr from the option - else: - # it's a scheduler state dict - self.current_lr = state_dict['current_lr'] - self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - # current_lr is actually useless in this case - - def rate(self, step = None): - "Implement `lrate` above" - if step is None: - step = self._step - return self.factor * \ - (self.model_size ** (-0.5) * - min(step ** (-0.5), step * self.warmup ** (-1.5))) - - def __getattr__(self, name): - return getattr(self.optimizer, name) - -def get_std_opt(model, optim_func='adam', factor=1, warmup=2000): - # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, - # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) - optim_func = dict(adam=torch.optim.Adam, - adamw=torch.optim.AdamW)[optim_func] - return NoamOpt(model.d_model, factor, warmup, - optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) diff --git a/captioning/utils/opts.py b/captioning/utils/opts.py deleted file mode 100644 index 778e512361727de0939bbd7b014e6eeb716a0c67..0000000000000000000000000000000000000000 --- a/captioning/utils/opts.py +++ /dev/null @@ -1,412 +0,0 @@ -from __future__ import print_function -import argparse - - -def if_use_feat(caption_model): - # Decide if load attention feature according to caption model - if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']: - use_att, use_fc = False, True - elif caption_model == 'language_model': - use_att, use_fc = False, False - elif caption_model in ['updown',
'topdown']: - use_fc, use_att = True, True - else: - use_att, use_fc = True, False - return use_fc, use_att - -import pprint -class Config(object): - def __init__(self, **kwargs): - """Configuration Class: set kwargs as class attributes with setattr""" - for k, v in kwargs.items(): - setattr(self, k, v) - - @property - def config_str(self): - return pprint.pformat(self.__dict__) - - def __repr__(self): - """Pretty-print configurations in alphabetical order""" - config_str = 'Configurations\n' - config_str += self.config_str - return config_str - - -def parse_opt(parse=True, **optional_kwargs): - parser = argparse.ArgumentParser() - # Data input settings - parser.add_argument('--input_json', type=str, default='data/coco.json', - help='path to the json file containing additional info and vocab') - parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', - help='path to the directory containing the preprocessed fc feats') - parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', - help='path to the directory containing the preprocessed att feats') - parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', - help='path to the directory containing the boxes of att feats') - parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', - help='path to the h5file containing the preprocessed dataset') - parser.add_argument('--data_in_memory', action='store_true', - help='True if we want to save the features in memory') - parser.add_argument('--start_from', type=str, default=None, - help="""continue training from saved model at this path. Path must contain files saved by previous training process: - 'infos.pkl' : configuration; - 'model.pth' : weights - """) - parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', - help='Cached token file for calculating cider score during self critical training.') - - # Model settings - parser.add_argument('--caption_model', type=str, default="show_tell", - help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer') - parser.add_argument('--rnn_size', type=int, default=512, - help='size of the rnn in number of hidden nodes in each layer') - parser.add_argument('--num_layers', type=int, default=1, - help='number of layers in the RNN') - parser.add_argument('--rnn_type', type=str, default='lstm', - help='rnn, gru, or lstm') - parser.add_argument('--input_encoding_size', type=int, default=512, - help='the encoding size of each token in the vocabulary, and the image.') - parser.add_argument('--att_hid_size', type=int, default=512, - help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') - parser.add_argument('--fc_feat_size', type=int, default=2048, - help='2048 for resnet, 4096 for vgg') - parser.add_argument('--att_feat_size', type=int, default=2048, - help='2048 for resnet, 512 for vgg') - parser.add_argument('--logit_layers', type=int, default=1, - help='number of layers in the RNN') - - - parser.add_argument('--use_bn', type=int, default=0, - help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') - - # feature manipulation - parser.add_argument('--norm_att_feat', type=int, default=0, - help='If normalize attention features') - parser.add_argument('--use_box', type=int, default=0, - help='If use box features') - parser.add_argument('--norm_box_feat', type=int, default=0, 
- help='If use box, do we normalize box feature') - - # Optimization: General - parser.add_argument('--max_epochs', type=int, default=-1, - help='number of epochs') - parser.add_argument('--batch_size', type=int, default=16, - help='minibatch size') - parser.add_argument('--grad_clip_mode', type=str, default='value', - help='value or norm') - parser.add_argument('--grad_clip_value', type=float, default=0.1, - help='clip gradients at this value/max_norm, 0 means no clipping') - parser.add_argument('--drop_prob_lm', type=float, default=0.5, - help='strength of dropout in the Language Model RNN') - parser.add_argument('--self_critical_after', type=int, default=-1, - help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)') - parser.add_argument('--seq_per_img', type=int, default=5, - help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') - - parser.add_argument('--verbose', type=int, default=0) - - # Sample related - add_eval_sample_opts(parser) - - #Optimization: for the Language Model - parser.add_argument('--optim', type=str, default='adam', - help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw') - parser.add_argument('--learning_rate', type=float, default=4e-4, - help='learning rate') - parser.add_argument('--learning_rate_decay_start', type=int, default=-1, - help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)') - parser.add_argument('--learning_rate_decay_every', type=int, default=3, - help='every how many iterations thereafter to drop LR?(in epoch)') - parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, - help='every how many iterations thereafter to drop LR?(in epoch)') - parser.add_argument('--optim_alpha', type=float, default=0.9, - help='alpha for adam') - parser.add_argument('--optim_beta', type=float, default=0.999, - help='beta used for adam') - parser.add_argument('--optim_epsilon', type=float, default=1e-8, - help='epsilon that goes into denominator for smoothing') - parser.add_argument('--weight_decay', type=float, default=0, - help='weight_decay') - # Transformer - parser.add_argument('--label_smoothing', type=float, default=0, - help='') - parser.add_argument('--noamopt', action='store_true', - help='') - parser.add_argument('--noamopt_warmup', type=int, default=2000, - help='') - parser.add_argument('--noamopt_factor', type=float, default=1, - help='') - parser.add_argument('--reduce_on_plateau', action='store_true', - help='') - parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, - help='') - parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, - help='') - parser.add_argument('--cached_transformer', action='store_true', - help='') - - - parser.add_argument('--use_warmup', action='store_true', - help='warm up the learing rate?') - - parser.add_argument('--scheduled_sampling_start', type=int, default=-1, - help='at what iteration to start decay gt probability') - parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, - help='every how many iterations thereafter to gt probability') - parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, - help='How much to update the prob') - parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, - help='Maximum scheduled sampling prob.') - - - # Evaluation/Checkpointing - 
parser.add_argument('--val_images_use', type=int, default=3200, - help='how many images to use when periodically evaluating the validation loss? (-1 = all)') - parser.add_argument('--save_checkpoint_every', type=int, default=2500, - help='how often to save a model checkpoint (in iterations)?') - parser.add_argument('--save_every_epoch', action='store_true', - help='Save checkpoint every epoch, will overwrite save_checkpoint_every') - parser.add_argument('--save_history_ckpt', type=int, default=0, - help='If save checkpoints at every save point') - parser.add_argument('--checkpoint_path', type=str, default=None, - help='directory to store checkpointed models') - parser.add_argument('--language_eval', type=int, default=0, - help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') - parser.add_argument('--losses_log_every', type=int, default=25, - help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') - parser.add_argument('--load_best_score', type=int, default=1, - help='Do we load previous best score when resuming training.') - - # misc - parser.add_argument('--id', type=str, default='', - help='an id identifying this run/job. used in cross-val and appended when writing progress files') - parser.add_argument('--train_only', type=int, default=0, - help='if true then use 80k, else use 110k') - - - # Reward - parser.add_argument('--cider_reward_weight', type=float, default=1, - help='The reward weight from cider') - parser.add_argument('--bleu_reward_weight', type=float, default=0, - help='The reward weight from bleu4') - - # Reward - parser.add_argument('--clipscore_reward_weight', type=float, default=1, - help='The reward weight from clipscore') - parser.add_argument('--use_clipscore', type=float, default=0, - help='Use CLIPScore') - parser.add_argument('--clipscore_mode', type=str, default='clip_s', - help='Which CLIPScore to use: clip_s|refclip_s') - - - # Structure_loss - parser.add_argument('--structure_loss_weight', type=float, default=1, - help='') - parser.add_argument('--structure_after', type=int, default=-1, - help='T') - parser.add_argument('--structure_loss_type', type=str, default='seqnll', - help='') - parser.add_argument('--struc_use_logsoftmax', action='store_true', help='') - parser.add_argument('--entropy_reward_weight', type=float, default=0, - help='Entropy reward, seems very interesting') - parser.add_argument('--self_cider_reward_weight', type=float, default=0, - help='self cider reward') - - # Used for self critical or structure. Used when sampling is need during training - parser.add_argument('--train_sample_n', type=int, default=16, - help='The reward weight from cider') - parser.add_argument('--train_sample_method', type=str, default='sample', - help='') - parser.add_argument('--train_beam_size', type=int, default=1, - help='') - - # Used for self critical - parser.add_argument('--sc_sample_method', type=str, default='greedy', - help='') - parser.add_argument('--sc_beam_size', type=int, default=1, - help='') - - - # For diversity evaluation during training - add_diversity_opts(parser) - - - # config - parser.add_argument('--cfg', type=str, default=None, - help='configuration; similar to what is used in detectron') - parser.add_argument( - '--set_cfgs', dest='set_cfgs', - help='Set config keys. Key value sequence seperate by whitespace.' - 'e.g. [key] [value] [key] [value]\n This has higher priority' - 'than cfg file but lower than other args. 
(You can only overwrite' - 'arguments that have already been defined in config file.)', - default=[], nargs='+') - # How will config be used - # 1) read cfg argument, and load the cfg file if it's not None - # 2) Overwrite cfg argument with set_cfgs - # 3) parse config argument to args. - # 4) in the end, parse command line argument and overwrite args - - # step 1: read cfg_fn - # args = parser.parse_args() - # Parse the arguments. - if parse: - args = parser.parse_args() - # For interactive environments (e.g. Jupyter) - else: - args = parser.parse_known_args()[0] - # print(args) - - # Namespace => Dictionary - kwargs = vars(args) - # for k, v in optional_kwargs.items(): - # setattr(args, k, v) - kwargs.update(optional_kwargs) - - args = Config(**kwargs) - - - if args.cfg is not None or args.set_cfgs is not None: - from .config import CfgNode - if args.cfg is not None: - # print('Read Cfg') - cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg)) - # print(cn) - else: - cn = CfgNode() - if args.set_cfgs is not None: - cn.merge_from_list(args.set_cfgs) - for k,v in cn.items(): - if not hasattr(args, k): - import os - if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - pass - else: - print('Warning: key %s not in args' % k) - - setattr(args, k, v) - - if parse: - args = parser.parse_args(namespace=args) - else: - args = parser.parse_known_args(namespace=args)[0] - - # Check if args are valid - assert args.rnn_size > 0, "rnn_size should be greater than 0" - assert args.num_layers > 0, "num_layers should be greater than 0" - assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" - assert args.batch_size > 0, "batch_size should be greater than 0" - assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" - assert args.seq_per_img > 0, "seq_per_img should be greater than 0" - assert args.beam_size > 0, "beam_size should be greater than 0" - assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" - assert args.losses_log_every > 0, "losses_log_every should be greater than 0" - assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" - assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1" - assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1" - - # default value for start_from and checkpoint_path - args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id - args.start_from = args.start_from or args.checkpoint_path - - # Deal with feature things before anything - args.use_fc, args.use_att = if_use_feat(args.caption_model) - if args.use_box: args.att_feat_size = args.att_feat_size + 5 - - return args - - -def add_eval_options(parser): - # Basic options - parser.add_argument('--batch_size', type=int, default=0, - help='if > 0 then overrule, otherwise load from checkpoint.') - parser.add_argument('--num_images', type=int, default=-1, - help='how many images to use when periodically evaluating the loss? (-1 = all)') - parser.add_argument('--language_eval', type=int, default=0, - help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') - parser.add_argument('--dump_images', type=int, default=1, - help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') - parser.add_argument('--dump_json', type=int, default=1, - help='Dump json with predictions into vis folder?
(1=yes,0=no)') - parser.add_argument('--dump_path', type=int, default=0, - help='Write image paths along with predictions into vis json? (1=yes,0=no)') - - # Sampling options - add_eval_sample_opts(parser) - - # For evaluation on a folder of images: - parser.add_argument('--image_folder', type=str, default='', - help='If this is nonempty then will predict on the images in this folder path') - parser.add_argument('--image_root', type=str, default='', - help='In case the image paths have to be preprended with a root path to an image folder') - # For evaluation on MSCOCO images from some split: - parser.add_argument('--input_fc_dir', type=str, default='', - help='path to the h5file containing the preprocessed dataset') - parser.add_argument('--input_att_dir', type=str, default='', - help='path to the h5file containing the preprocessed dataset') - parser.add_argument('--input_box_dir', type=str, default='', - help='path to the h5file containing the preprocessed dataset') - parser.add_argument('--input_label_h5', type=str, default='', - help='path to the h5file containing the preprocessed dataset') - parser.add_argument('--input_json', type=str, default='', - help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') - parser.add_argument('--split', type=str, default='test', - help='if running on MSCOCO images, which split to use: val|test|train') - parser.add_argument('--coco_json', type=str, default='', - help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') - # misc - parser.add_argument('--id', type=str, default='', - help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') - parser.add_argument('--verbose_beam', type=int, default=1, - help='if we need to print out all beam search beams.') - parser.add_argument('--verbose_loss', type=int, default=0, - help='If calculate loss using ground truth during evaluation') - -def add_diversity_opts(parser): - parser.add_argument('--sample_n', type=int, default=1, - help='Diverse sampling') - parser.add_argument('--sample_n_method', type=str, default='sample', - help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp') - parser.add_argument('--eval_oracle', type=int, default=1, - help='if we need to calculate loss.') - - -# Sampling related options -def add_eval_sample_opts(parser): - parser.add_argument('--sample_method', type=str, default='greedy', - help='greedy; sample; gumbel; top, top<0-1>') - parser.add_argument('--beam_size', type=int, default=1, - help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') - parser.add_argument('--max_length', type=int, default=20, - help='Maximum length during sampling') - parser.add_argument('--length_penalty', type=str, default='', - help='wu_X or avg_X, X is the alpha') - parser.add_argument('--group_size', type=int, default=1, - help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') - parser.add_argument('--diversity_lambda', type=float, default=0.5, - help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') - parser.add_argument('--temperature', type=float, default=1.0, - help='temperature when sampling from distributions (i.e. when sample_method = sample). 
Lower = "safer" predictions.') - parser.add_argument('--decoding_constraint', type=int, default=0, - help='If 1, not allowing same word in a row') - parser.add_argument('--block_trigrams', type=int, default=0, - help='block repeated trigram.') - parser.add_argument('--remove_bad_endings', type=int, default=0, - help='Remove bad endings') - parser.add_argument('--suppress_UNK', type=int, default=1, - help='Not predicting UNK') - - -if __name__ == '__main__': - import sys - sys.argv = [sys.argv[0]] - args = parse_opt() - print(args) - print() - sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml'] - args1 = parse_opt() - print(dict(set(vars(args1).items()) - set(vars(args).items()))) - print() - sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2'] - args2 = parse_opt() - print(dict(set(vars(args2).items()) - set(vars(args1).items()))) diff --git a/captioning/utils/resnet.py b/captioning/utils/resnet.py deleted file mode 100644 index e8aaff426d5d6c837f6dc49eefa16a31fc1834de..0000000000000000000000000000000000000000 --- a/captioning/utils/resnet.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.models.resnet -from torchvision.models.resnet import BasicBlock, Bottleneck - -class ResNet(torchvision.models.resnet.ResNet): - def __init__(self, block, layers, num_classes=1000): - super(ResNet, self).__init__(block, layers, num_classes) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change - for i in range(2, 5): - getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) - getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) - -def resnet18(pretrained=False): - """Constructs a ResNet-18 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(BasicBlock, [2, 2, 2, 2]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) - return model - - -def resnet34(pretrained=False): - """Constructs a ResNet-34 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(BasicBlock, [3, 4, 6, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) - return model - - -def resnet50(pretrained=False): - """Constructs a ResNet-50 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 4, 6, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) - return model - - -def resnet101(pretrained=False): - """Constructs a ResNet-101 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 4, 23, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) - return model - - -def resnet152(pretrained=False): - """Constructs a ResNet-152 model. 
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet(Bottleneck, [3, 8, 36, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) - return model \ No newline at end of file diff --git a/captioning/utils/resnet_utils.py b/captioning/utils/resnet_utils.py deleted file mode 100644 index e1df171ab75700352333f6af5d59f751819b57f6..0000000000000000000000000000000000000000 --- a/captioning/utils/resnet_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class myResnet(nn.Module): - def __init__(self, resnet): - super(myResnet, self).__init__() - self.resnet = resnet - - def forward(self, img, att_size=14): - x = img.unsqueeze(0) - - x = self.resnet.conv1(x) - x = self.resnet.bn1(x) - x = self.resnet.relu(x) - x = self.resnet.maxpool(x) - - x = self.resnet.layer1(x) - x = self.resnet.layer2(x) - x = self.resnet.layer3(x) - x = self.resnet.layer4(x) - - fc = x.mean(3).mean(2).squeeze() - att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) - - return fc, att - diff --git a/captioning/utils/rewards.py b/captioning/utils/rewards.py deleted file mode 100644 index 668b830cbdef05d6c3eab8d99a07918a325e9157..0000000000000000000000000000000000000000 --- a/captioning/utils/rewards.py +++ /dev/null @@ -1,392 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import time -from collections import OrderedDict -import torch - -import sys -try: - sys.path.append("cider") - from pyciderevalcap.ciderD.ciderD import CiderD - from pyciderevalcap.cider.cider import Cider - sys.path.append("coco-caption") - from pycocoevalcap.bleu.bleu import Bleu -except: - print('cider or coco-caption missing') - -CiderD_scorer = None -Cider_scorer = None -Bleu_scorer = None -#CiderD_scorer = CiderD(df='corpus') - - -from .misc import decode_sequence - -def init_scorer(cached_tokens): - global CiderD_scorer - CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) - global Cider_scorer - Cider_scorer = Cider_scorer or Cider(df=cached_tokens) - global Bleu_scorer - Bleu_scorer = Bleu_scorer or Bleu(4) - -def array_to_str(arr): - out = '' - for i in range(len(arr)): - out += str(arr[i]) + ' ' - if arr[i] == 0: - break - return out.strip() - -def get_self_critical_reward(greedy_res, data_gts, gen_result, opt): - batch_size = len(data_gts) - gen_result_size = gen_result.shape[0] - seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img - assert greedy_res.shape[0] == batch_size - - res = OrderedDict() - gen_result = gen_result.data.cpu().numpy() - greedy_res = greedy_res.data.cpu().numpy() - for i in range(gen_result_size): - res[i] = [array_to_str(gen_result[i])] - for i in range(batch_size): - res[gen_result_size + i] = [array_to_str(greedy_res[i])] - - gts = OrderedDict() - for i in range(len(data_gts)): - gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] - - res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))] - res__ = {i: res[i] for i in range(len(res_))} - gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)} - gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)}) - if opt.cider_reward_weight > 0: - _, cider_scores = CiderD_scorer.compute_score(gts_, res_) - if hasattr(opt, 'verbose') and not opt.verbose: - pass - else: - print('Cider scores:', _) - else: - cider_scores = 
0 - if opt.bleu_reward_weight > 0: - _, bleu_scores = Bleu_scorer.compute_score(gts_, res__) - bleu_scores = np.array(bleu_scores[3]) - if hasattr(opt, 'verbose') and not opt.verbose: - pass - else: - print('Bleu scores:', _[3]) - else: - bleu_scores = 0 - scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores - - unnormalized_reward_mean = scores[:gen_result_size].flatten().mean() - - scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis] - - scores = scores.reshape(gen_result_size) - - rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) - - return rewards, unnormalized_reward_mean - - -def get_self_critical_clipscore_reward(greedy_res, data_gts, gen_result, opt, clipscore_model, clip_vis_feats, vocab): - batch_size = len(data_gts) - gen_result_size = gen_result.shape[0] - seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img - assert greedy_res.shape[0] == batch_size - - B = batch_size - K = seq_per_img - L = gen_result.shape[1] - assert gen_result.shape == (B*K , L) - - # res = OrderedDict() - # gen_result = gen_result.data.cpu().numpy() - # greedy_res = greedy_res.data.cpu().numpy() - # for i in range(gen_result_size): - # res[i] = [array_to_str(gen_result[i])] - # for i in range(batch_size): - # res[gen_result_size + i] = [array_to_str(greedy_res[i])] - - # gts = OrderedDict() - # for i in range(len(data_gts)): - # gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] - - # res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))] - # res__ = {i: res[i] for i in range(len(res_))} - # gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)} - # gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)}) - - # res = [] - # gen_result = gen_result.data.cpu().numpy() - # greedy_res = greedy_res.data.cpu().numpy() - # # for i in range(gen_result_size): - # # res.append(array_to_str(gen_result[i])) - # res.extend(decode_sequence(vocab, gen_result)) - - - # # for i in range(batch_size): - # # res.append(array_to_str(greedy_res[i])) - # res.extend(decode_sequence(vocab, greedy_res)) - - if clipscore_model.mode == 'refclip_s': - gts = [] - gts_valid_mask = [] - max_n_refs = max([len(_gts) for _gts in data_gts]) - for i in range(len(data_gts)): - _gts = decode_sequence(vocab, data_gts[i]) - # pad references - n_ref = len(_gts) - _gts.extend([''] * (max_n_refs - n_ref)) - gts.extend(_gts) - gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref)) - assert len(gts) == B * max_n_refs - assert len(gts_valid_mask) == B * max_n_refs - - # print(gts) - # print(gts_valid_mask) - # exit() - - - # assert len(res) == B * K + B, len(res) - - # print(res) - # exit() - - if opt.clipscore_reward_weight > 0: - with torch.no_grad(): - clipscore_model.eval() - - # 1) calculate reward - gen_result = gen_result.data.cpu().numpy() - res = decode_sequence(vocab, gen_result) - assert len(res) == B * K, len(res) - - # [B * K, dim) - if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False): - text_pre_feat = clipscore_model.text_extract(res, proj_norm=False) - - grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512)) - grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1] - grammar_prob = grammar_prob.view(B*K).detach() - - text_feat = clipscore_model.clip_model.text_projection(text_pre_feat) - text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) - - else: - text_feat = 
clipscore_model.text_extract(res) - - - assert text_feat.size() == (B * K, 512), text_feat.size() - assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size() - - # [B * K, dim] - vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1) - - clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s') - clip_s = clip_s.view(B * K).detach() - - if clipscore_model.mode == 'refclip_s': - # [B * n_ref, dim] - ref_text_feat = clipscore_model.text_extract(gts) - ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device) - - assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size() - assert ref_text_mask.size() == (B * max_n_refs,), ref_text_mask.size() - - # [B * K] - refclip_s = clipscore_model.calc_refclip_s( - text_feat=text_feat, img_feat=vis_feat, - ref_text_feat=ref_text_feat.view(B, 1, max_n_refs, -1).expand(-1, K, -1, -1).contiguous().view(B * K * max_n_refs, -1), - ref_text_mask=ref_text_mask.view(B, 1, max_n_refs).expand(-1, K, -1).contiguous().view(B * K * max_n_refs), - clip_s=clip_s) - refclip_s = refclip_s.view(B * K).detach() - - # 2) calcualte reward for baseline (greedy) - greedy_res = greedy_res.data.cpu().numpy() - res = decode_sequence(vocab, greedy_res) - assert len(res) == B, len(res) - - # [B, dim) - - if getattr(opt, 'use_grammar', False) and getattr(opt, 'use_grammar_baseline', False) and not getattr(opt, 'joint_out', False): - text_pre_feat = clipscore_model.text_extract(res, proj_norm=False) - - grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512)) - grammar_prob_baseline = torch.softmax(grammar_logit, dim=-1)[:, 1] - grammar_prob_baseline = grammar_prob_baseline.view(B).detach() - - text_feat = clipscore_model.clip_model.text_projection(text_pre_feat) - text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) - else: - text_feat = clipscore_model.text_extract(res) - - assert text_feat.size() == (B, 512), text_feat.size() - assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size() - - vis_feat = clip_vis_feats.view(B, 512) - - # [B] - clip_s_baseline = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s') - clip_s_baseline = clip_s_baseline.view(B).detach() - - if clipscore_model.mode == 'refclip_s': - # # [B * n_ref] - # ref_text_feat = clipscore_model.text_extract(gts) - # ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device) - # assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size() - # assert ref_text_mask.size() == (B * max_n_refs), ref_text_mask.size() - - # [B] - refclip_s_baseline = clipscore_model.calc_refclip_s( - text_feat=text_feat, img_feat=vis_feat, - ref_text_feat=ref_text_feat, - ref_text_mask=ref_text_mask, - clip_s=clip_s_baseline) - refclip_s_baseline = refclip_s_baseline.view(B).detach() - - if clipscore_model.mode == 'clip_s': - rewards = clip_s - clip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten() - unnormalized_mean_reward = clip_s.mean() - elif clipscore_model.mode == 'refclip_s': - rewards = refclip_s - refclip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten() - unnormalized_mean_reward = refclip_s.mean() - - # # [B * K + B, dim) - # text_feat = clipscore_model.text_extract(res) - # assert text_feat.size() == (B * K + B, 512), text_feat.size() - - # assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size() - - # # [B, dim] -> [B * K + B, dim] - # # vis_feat = clip_vis_feats.view(B, 1, 
-1).expand(-1, K + 1, -1).contiguous().view(B * (K + 1), -1) - # # vis_feat = clip_vis_feats.view(1, B, -1).expand(K + 1, -1, -1).contiguous().view((K + 1) * B, -1) - - # # [B * K, dim] - # gen_vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1) - # # [B, dim] - # greedy_vis_feat = clip_vis_feats - # # [B * K + B, dim] - # vis_feat = torch.cat([gen_vis_feat, greedy_vis_feat], dim=0) - - # # if clipscore_model.mode == 'clip_s': - # # [B * K + B, dim] - # clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat) - # clip_s = clip_s.view(B * K + B).detach() - - - # if clipscore_model.mode == 'refclip_s': - # # [B * K, dim] - # ref_text_feat = clipscore_model.text_extract(gts) - - # clipscore_scores = clipscore_model.calc_refclip_s(text_feat=text_feat, img_feat=vis_feat, ref_text_feat=ref_text_feat, clip_s=clip_s) - # clipscore_scores = clipscore_scores.view(B * K + B).detach() - - if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False): - - if getattr(opt, 'use_grammar_baseline', False): - grammar_rewards = grammar_prob - grammar_prob_baseline.view(B, 1).expand(-1, K).contiguous().flatten() - else: - grammar_rewards = grammar_prob - else: - grammar_rewards = None - - - if hasattr(opt, 'verbose') and not opt.verbose: - pass - else: - if clipscore_model.mode == 'clip_s': - print('CLIP-S:', rewards) - elif clipscore_model.mode == 'refclip_s': - print('RefCLIP-S:', rewards) - else: - rewards = torch.zeros(B, L) - unnormalized_mean_reward = None - grammar_rewards = None - - - rewards = opt.clipscore_reward_weight * rewards - - - # scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis] - # scores = scores.reshape(gen_result_size) - # rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) - - # [B, K] - # scores = scores[:gen_result_size].reshape(B, K) - scores[-B:].unsqueeze(1) - - # [B*K, L] - # rewards = scores.view(-1, 1).expand(-1, L).contiguous() - rewards = rewards.view(-1, 1).expand(-1, L).contiguous() - - if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False): - grammar_rewards = grammar_rewards.view(-1, 1).expand(-1, L).contiguous() - - return rewards, unnormalized_mean_reward, grammar_rewards - -def get_scores(data_gts, gen_result, opt): - batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img - seq_per_img = batch_size // len(data_gts) - - res = OrderedDict() - - gen_result = gen_result.data.cpu().numpy() - for i in range(batch_size): - res[i] = [array_to_str(gen_result[i])] - - gts = OrderedDict() - for i in range(len(data_gts)): - gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] - - res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)] - res__ = {i: res[i] for i in range(batch_size)} - gts = {i: gts[i // seq_per_img] for i in range(batch_size)} - if opt.cider_reward_weight > 0: - _, cider_scores = CiderD_scorer.compute_score(gts, res_) - # print('Cider scores:', _) - if hasattr(opt, 'verbose') and not opt.verbose: - pass - else: - print('Cider scores:', _) - else: - cider_scores = 0 - if opt.bleu_reward_weight > 0: - _, bleu_scores = Bleu_scorer.compute_score(gts, res__) - bleu_scores = np.array(bleu_scores[3]) - # print('Bleu scores:', _[3]) - if hasattr(opt, 'verbose') and not opt.verbose: - pass - else: - print('Bleu scores:', _[3]) - else: - bleu_scores = 0 - - scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores - - return scores - -def 
get_self_cider_scores(data_gts, gen_result, opt): - batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img - seq_per_img = batch_size // len(data_gts) - - res = [] - - gen_result = gen_result.data.cpu().numpy() - for i in range(batch_size): - res.append(array_to_str(gen_result[i])) - - scores = [] - for i in range(len(data_gts)): - tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]]) - def get_div(eigvals): - eigvals = np.clip(eigvals, 0, None) - return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals)) - scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10))) - - scores = np.array(scores) - - return scores diff --git a/captioning/utils/utils.py b/captioning/utils/utils.py deleted file mode 100644 index 85e12a8a1fcb5be1fa6b8833381b0a7918add5c4..0000000000000000000000000000000000000000 --- a/captioning/utils/utils.py +++ /dev/null @@ -1,138 +0,0 @@ -import re -import numpy as np -import torch -import torch.distributed as dist -import collections -import logging - -def get_area(pos): - """ - Args - pos: [B, N, 4] - (x1, x2, y1, y2) - - Return - area : [B, N] - """ - # [B, N] - height = pos[:, :, 3] - pos[:, :, 2] - width = pos[:, :, 1] - pos[:, :, 0] - area = height * width - return area - -def get_relative_distance(pos): - """ - Args - pos: [B, N, 4] - (x1, x2, y1, y2) - - Return - out : [B, N, N, 4] - """ - # B, N = pos.size()[:-1] - - # [B, N, N, 4] - relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2) - - return relative_distance - - -class LossMeter(object): - def __init__(self, maxlen=100): - """Computes and stores the running average""" - self.vals = collections.deque([], maxlen=maxlen) - - def __len__(self): - return len(self.vals) - - def update(self, new_val): - self.vals.append(new_val) - - @property - def val(self): - return sum(self.vals) / len(self.vals) - - def __repr__(self): - return str(self.val) - - -def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def load_state_dict(state_dict_path, loc='cpu'): - state_dict = torch.load(state_dict_path, map_location=loc) - # Change Multi GPU to single GPU - original_keys = list(state_dict.keys()) - for key in original_keys: - if key.startswith("module."): - new_key = key[len("module."):] - state_dict[new_key] = state_dict.pop(key) - return state_dict - - -def set_global_logging_level(level=logging.ERROR, prefices=[""]): - """ - Override logging levels of different modules based on their name as a prefix. - It needs to be invoked after the modules have been loaded so that their loggers have been initialized. - - Args: - - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR - - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. - Default is `[""]` to match all active loggers. 
- The match is a case-sensitive `module_name.startswith(prefix)` - """ - prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') - for name in logging.root.manager.loggerDict: - if re.match(prefix_re, name): - logging.getLogger(name).setLevel(level) - - -def get_iou(anchors, gt_boxes): - """ - anchors: (N, 4) torch floattensor - gt_boxes: (K, 4) torch floattensor - overlaps: (N, K) ndarray of overlap between boxes and query_boxes - """ - N = anchors.size(0) - - if gt_boxes.size() == (4,): - gt_boxes = gt_boxes.view(1, 4) - K = gt_boxes.size(0) - - gt_boxes_area = ( - (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * - (gt_boxes[:, 3] - gt_boxes[:, 1] + 1) - ).view(1, K) - - anchors_area = ( - (anchors[:, 2] - anchors[:, 0] + 1) * - (anchors[:, 3] - anchors[:, 1] + 1) - ).view(N, 1) - - boxes = anchors.view(N, 1, 4).expand(N, K, 4) - query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4) - - iw = ( - torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) - - torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) - + 1 - ) - iw[iw < 0] = 0 - - ih = ( - torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) - - torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) - + 1 - ) - ih[ih < 0] = 0 - - ua = anchors_area + gt_boxes_area - (iw * ih) - overlaps = iw * ih / ua - - return overlaps - - -def xywh_to_xyxy(boxes): - """Convert [x y w h] box format to [x1 y1 x2 y2] format.""" - return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1)) diff --git a/clip/__init__.py b/clip/__init__.py deleted file mode 100644 index dcc5619538c0f7c782508bdbd9587259d805e0d9..0000000000000000000000000000000000000000 --- a/clip/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .clip import * diff --git a/clip/bpe_simple_vocab_16e6.txt.gz b/clip/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/clip/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/clip/clip.py b/clip/clip.py deleted file mode 100644 index 76f241b053e3a6da06b1165e73e0d54c5b5356b2..0000000000000000000000000000000000000000 --- a/clip/clip.py +++ /dev/null @@ -1,193 +0,0 @@ -import hashlib -import os -import urllib -import warnings -from typing import Union, List - -import torch -from PIL import Image -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from tqdm import tqdm - -from .model import build_model -from .simple_tokenizer import SimpleTokenizer as _Tokenizer - -__all__ = ["available_models", "load", "tokenize"] -_tokenizer = _Tokenizer() - -_MODELS = { - "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", -} - - -def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - expected_sha256 = url.split("/")[-2] - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not 
os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: - return download_target - else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: - raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") - - return download_target - - -def _transform(n_px): - return Compose([ - Resize(n_px, interpolation=Image.BICUBIC), - CenterCrop(n_px), - lambda image: image.convert("RGB"), - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ]) - - -def available_models() -> List[str]: - """Returns the names of available CLIP models""" - return list(_MODELS.keys()) - - -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - - device : Union[str, torch.device] - The device to put the loaded model - - jit : bool - Whether to load the optimized JIT model (default) or the more hackable non-JIT model. - - Returns - ------- - model : torch.nn.Module - The CLIP model - - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if name in _MODELS: - model_path = _download(_MODELS[name]) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive.
Loading as a state dict instead") - jit = False - state_dict = torch.load(model_path, map_location="cpu") - - if not jit: - model = build_model(state_dict or model.state_dict()).to(device) - if str(device) == "cpu": - model.float() - return model, _transform(model.visual.input_resolution) - - # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] - - def patch_device(module): - graphs = [module.graph] if hasattr(module, "graph") else [] - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - graphs = [module.graph] if hasattr(module, "graph") else [] - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) - - model.float() - - return model, _transform(model.input_resolution.item()) - - -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder["<|startoftext|>"] - eot_token = _tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result diff --git a/clip/model.py b/clip/model.py deleted file mode 100644 index 049391e9816d7faf00bdab95a08b99a99c3c405a..0000000000000000000000000000000000000000 --- a/clip/model.py +++ /dev/null @@ -1,437 +0,0 @@ -from collections import OrderedDict -from typing import Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - # print(x.shape, self.positional_embedding.shape) - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=torch.ones_like(self.q_proj.weight), - out_proj_bias=torch.zeros_like(self.q_proj.bias), - # out_proj_weight=self.c_proj.weight, - # out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - def stem(x): - for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - x = x.type(self.conv1.weight.dtype) - x = stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - # print(x.shape) - # x = self.attnpool(x) - attnpool = self.attnpool(x) - - return (x, attnpool) - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - 
-class VisualTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - # x = self.ln_post(x[:, 0, :]) - - x = self.ln_post(x) - # if self.proj is not None: - # x = x @ self.proj - - return x - - -class CLIP(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width - ) - else: - vision_heads = vision_width // 64 - self.visual = VisualTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([])) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.visual.layer1, self.visual.layer2, 
self.visual.layer3, self.visual.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image): - return self.visual(image.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logit_scale * text_features @ image_features.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model(state_dict: dict): - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_resolution = vision_patch_size * grid_size - 
else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - model.load_state_dict(state_dict) - return model.eval() diff --git a/clip/simple_tokenizer.py b/clip/simple_tokenizer.py deleted file mode 100644 index 0a66286b7d5019c6e221932a813768038f839c91..0000000000000000000000000000000000000000 --- a/clip/simple_tokenizer.py +++ /dev/null @@ -1,132 +0,0 @@ -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). 
- """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text diff --git a/configs/phase1/FineCapEval_clipRN50_mle.yml b/configs/phase1/FineCapEval_clipRN50_mle.yml deleted file mode 100644 index 0f71ae39417dbd8f1afc25ccd78689c04b746ad3..0000000000000000000000000000000000000000 --- a/configs/phase1/FineCapEval_clipRN50_mle.yml +++ /dev/null @@ -1,60 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/FineCapEval.json -input_label_h5: none -input_fc_dir: data/FineCapEval_clip_RN50_fc -input_att_dir: data/FineCapEval_clip_RN50_att -input_clipscore_vis_dir: data/FineCapEval_clipscore_vis - -seq_per_img: 5 -batch_size: 200 -learning_rate: 0.0005 - -checkpoint_path: ./save/clipRN50_mle/clipRN50_mle - -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs 
to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 50 - -verbose: false -precision: 32 - -use_clipscore: false \ No newline at end of file diff --git a/configs/phase1/clipRN50_mle.yml b/configs/phase1/clipRN50_mle.yml deleted file mode 100644 index 4756d12c6156724db6f9e7025b28276b86125c5e..0000000000000000000000000000000000000000 --- a/configs/phase1/clipRN50_mle.yml +++ /dev/null @@ -1,52 +0,0 @@ -caption_model: transformer -noamopt: true -# noamopt: false -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_fc_dir: data/cocotalk_clip_RN50_fc -input_att_dir: data/cocotalk_clip_RN50_att -input_clipscore_vis_dir: data/cocotalk_clipscore_vis -seq_per_img: 5 -# batch_size: 600 -batch_size: 200 - -learning_rate: 0.0005 - -# checkpoint_path: ./save/trans_clip_rn50_sc_pl -checkpoint_path: save/clipRN50_mle/clipRN50_mle - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -# max_epochs: 15 -max_epochs: 25 -train_sample_n: 5 - -REFORWARD: false - - -verbose: false -precision: 16 \ No newline at end of file diff --git a/configs/phase1/transformer.yml b/configs/phase1/transformer.yml deleted file mode 100644 index 3dfa9f78b14a8fbec12a4d1177fa489942f861c7..0000000000000000000000000000000000000000 --- a/configs/phase1/transformer.yml +++ /dev/null @@ -1,41 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_att_dir: data/cocotalk_att -seq_per_img: 5 -batch_size: 10 -learning_rate: 0.0005 - -checkpoint_path: ./save/trans_rn50_sc - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_cider.yml b/configs/phase2/FineCapEval_clipRN50_cider.yml deleted file mode 100644 index 52cac145b854455e92d6ade17be017317907a76a..0000000000000000000000000000000000000000 --- a/configs/phase2/FineCapEval_clipRN50_cider.yml +++ /dev/null @@ -1,61 +0,0 @@ -caption_model: 
transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/FineCapEval.json -input_label_h5: none -input_fc_dir: data/FineCapEval_clip_RN50_fc -input_att_dir: data/FineCapEval_clip_RN50_att -input_clipscore_vis_dir: data/FineCapEval_clipscore_vis - -seq_per_img: 5 -batch_size: 200 -learning_rate: 0.0005 - -checkpoint_path: ./save/clipRN50_cider/clipRN50_cider - -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 50 - -verbose: false -precision: 32 - -# use_clipscore: true -use_clipscore: false \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_cider_clips.yml b/configs/phase2/FineCapEval_clipRN50_cider_clips.yml deleted file mode 100644 index a74ee8b6d71e3bd260713f77be5ab9d4c8f4ad5d..0000000000000000000000000000000000000000 --- a/configs/phase2/FineCapEval_clipRN50_cider_clips.yml +++ /dev/null @@ -1,65 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/FineCapEval.json -input_label_h5: none -input_fc_dir: data/FineCapEval_clip_RN50_fc -input_att_dir: data/FineCapEval_clip_RN50_att -input_clipscore_vis_dir: data/FineCapEval_clipscore_vis - -seq_per_img: 5 -batch_size: 200 -learning_rate: 0.0005 - -checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips - -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 50 - -verbose: false -precision: 32 - -# use_clipscore: true -use_clipscore: false -clipscore_reward_weight: 2.0 -clipscore_mode: clip_s - -use_multi_rewards: true \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_clips.yml b/configs/phase2/FineCapEval_clipRN50_clips.yml deleted file mode 100644 index 5440a45f3196995e2ccfb6e61f88a149fee72b2f..0000000000000000000000000000000000000000 --- a/configs/phase2/FineCapEval_clipRN50_clips.yml +++ /dev/null @@ -1,64 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/FineCapEval.json -input_label_h5: none -input_fc_dir: 
data/FineCapEval_clip_RN50_fc -input_att_dir: data/FineCapEval_clip_RN50_att -input_clipscore_vis_dir: data/FineCapEval_clipscore_vis -seq_per_img: 5 -batch_size: 160 -learning_rate: 0.0005 - -checkpoint_path: ./save/clipRN50_clips/clipRN50_clips - -use_multi_rewards: false -use_grammar: false -use_grammar_baseline: false -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 0 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 50 - -verbose: false -precision: 32 - -# use_clipscore: true -use_clipscore: false -clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml b/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml deleted file mode 100644 index 854394e9125a81c7351c555dc598eb541eaf20d3..0000000000000000000000000000000000000000 --- a/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +++ /dev/null @@ -1,64 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/FineCapEval.json -input_label_h5: none -input_fc_dir: data/FineCapEval_clip_RN50_fc -input_att_dir: data/FineCapEval_clip_RN50_att -input_clipscore_vis_dir: data/FineCapEval_clipscore_vis -seq_per_img: 5 -batch_size: 160 -learning_rate: 0.0005 - -checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar - -use_multi_rewards: true -use_grammar: true -use_grammar_baseline: true -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 0 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 50 - -verbose: false -precision: 32 - -# use_clipscore: true -use_clipscore: false -clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/clipRN50_cider.yml b/configs/phase2/clipRN50_cider.yml deleted file mode 100644 index 924b2dacecf012f158502136169b0340d37e9a47..0000000000000000000000000000000000000000 --- a/configs/phase2/clipRN50_cider.yml +++ /dev/null @@ -1,58 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_fc_dir: data/cocotalk_clip_RN50_fc -input_att_dir: 
data/cocotalk_clip_RN50_att -# used only for evaluation -input_clipscore_vis_dir: data/cocotalk_clipscore_vis - -seq_per_img: 5 -batch_size: 200 -learning_rate: 0.0005 - -# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider -checkpoint_path: save/clipRN50_cider/clipRN50_cider - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 40 - -verbose: false -precision: 32 \ No newline at end of file diff --git a/configs/phase2/clipRN50_cider_clips.yml b/configs/phase2/clipRN50_cider_clips.yml deleted file mode 100644 index d1b0f3ff7ce92d80fcb1f77b769cfadec471bc45..0000000000000000000000000000000000000000 --- a/configs/phase2/clipRN50_cider_clips.yml +++ /dev/null @@ -1,61 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_fc_dir: data/cocotalk_clip_RN50_fc -input_att_dir: data/cocotalk_clip_RN50_att -input_clipscore_vis_dir: data/cocotalk_clipscore_vis -seq_per_img: 5 -batch_size: 160 -learning_rate: 0.0005 - -checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 40 - -verbose: false -precision: 32 - -use_clipscore: true -clipscore_reward_weight: 2.0 -clipscore_mode: clip_s - -use_multi_rewards: true \ No newline at end of file diff --git a/configs/phase2/clipRN50_clips.yml b/configs/phase2/clipRN50_clips.yml deleted file mode 100644 index 2b62f5c5d5cbc8ab5c8ece8faa87adcf7a0e70fa..0000000000000000000000000000000000000000 --- a/configs/phase2/clipRN50_clips.yml +++ /dev/null @@ -1,58 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_fc_dir: data/cocotalk_clip_RN50_fc -input_att_dir: data/cocotalk_clip_RN50_att -input_clipscore_vis_dir: data/cocotalk_clipscore_vis -seq_per_img: 5 -batch_size: 160 -learning_rate: 0.0005 - -checkpoint_path: save/clipRN50_clips/clipRN50_clips - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 
-input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 40 - -verbose: false -precision: 32 - -use_clipscore: true -clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/clipRN50_clips_grammar.yml b/configs/phase2/clipRN50_clips_grammar.yml deleted file mode 100644 index c9db26ff17158568d0f3d2a63837f3925dc007b8..0000000000000000000000000000000000000000 --- a/configs/phase2/clipRN50_clips_grammar.yml +++ /dev/null @@ -1,64 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_fc_dir: data/cocotalk_clip_RN50_fc -input_att_dir: data/cocotalk_clip_RN50_att -input_clipscore_vis_dir: data/cocotalk_clipscore_vis -seq_per_img: 5 -batch_size: 160 -learning_rate: 0.0005 - -checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar - -use_multi_rewards: true -use_grammar: true -use_grammar_baseline: true -# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' -clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt' - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false - -# _BASE_: transformer.yml -reduce_on_plateau: false -noamopt: false -learning_rate: 0.000005 -learning_rate_decay_start: -1 - -self_critical_after: 15 -max_epochs: 40 - -verbose: false -precision: 32 - -use_clipscore: true -clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/transformer.yml b/configs/phase2/transformer.yml deleted file mode 100644 index 3dfa9f78b14a8fbec12a4d1177fa489942f861c7..0000000000000000000000000000000000000000 --- a/configs/phase2/transformer.yml +++ /dev/null @@ -1,41 +0,0 @@ -caption_model: transformer -noamopt: true -noamopt_warmup: 20000 -label_smoothing: 0.0 -input_json: data/cocotalk.json -input_label_h5: data/cocotalk_label.h5 -input_att_dir: data/cocotalk_att -seq_per_img: 5 -batch_size: 10 -learning_rate: 0.0005 - -checkpoint_path: ./save/trans_rn50_sc - -# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: -# N=num_layers -# d_model=input_encoding_size -# d_ff=rnn_size - -# will be ignored -num_layers: 6 -input_encoding_size: 512 -rnn_size: 2048 - -# Transformer config -N_enc: 6 -N_dec: 6 -d_model: 512 -d_ff: 2048 -num_att_heads: 8 -dropout: 0.1 - - -learning_rate_decay_start: 0 -scheduled_sampling_start: -1 -save_checkpoint_every: 3000 -language_eval: 1 -val_images_use: 5000 -max_epochs: 15 -train_sample_n: 5 - -REFORWARD: false \ No newline at end of file diff 
--git a/data/README.md b/data/README.md deleted file mode 100644 index c786a9e85300c02f477a4d977cee587f35162b0d..0000000000000000000000000000000000000000 --- a/data/README.md +++ /dev/null @@ -1 +0,0 @@ -directory to store preprocessed files \ No newline at end of file diff --git a/retrieval/README.md b/retrieval/README.md deleted file mode 100644 index 2f5cce9ad9b93234fa5ac7a0e99f05868d883fd0..0000000000000000000000000000000000000000 --- a/retrieval/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Finetuning CLIP reward model - -```bash -python train_pl.py --cfg clip_negative_text --id clip_negative_text -``` \ No newline at end of file diff --git a/retrieval/caption_data.py b/retrieval/caption_data.py deleted file mode 100644 index 595a81ae5346937e5d9174401cd8a62e78946864..0000000000000000000000000000000000000000 --- a/retrieval/caption_data.py +++ /dev/null @@ -1,500 +0,0 @@ -from torch.utils.data import DataLoader, Dataset, Sampler -from pathlib import Path -import json -from multiprocessing import Pool -from tqdm import tqdm -from PIL import Image -import random -import numpy as np -import torch -import torchvision -import torchvision.transforms as T - -from torch.utils.data.distributed import DistributedSampler - -from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer - -import text_utils - -project_dir = Path(__file__).parent.resolve() -workspace_dir = project_dir.parent.parent -dataset_dir = workspace_dir.joinpath('datasets/').resolve() -# coco_dir = dataset_dir.joinpath('COCO') -# vg_dir = dataset_dir.joinpath('VG') -coco_img_dir = dataset_dir.joinpath('COCO/images/') -coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/') -# coco_feature_dir = coco_dir.joinpath('features') - - -class COCORetrievalDataset(Dataset): - def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'): - super().__init__() - - self.topk = topk - self.verbose = verbose - self.args = args - self.rank = rank - self.mode = mode - - # Loading datasets to data - self.source = split - if self.verbose: - print('Data source: ', self.source) - - # if self.args.tokenizer is None: - # self.args.tokenizer = self.args.decoder_backbone - - # if 'bert' in self.args.tokenizer: - # self.tokenizer = BertTokenizerFast.from_pretrained( - # self.args.tokenizer, - # # max_length=self.args.max_text_length, - # # do_lower_case=self.args.do_lower_case - # ) - # elif 'clip' in self.args.tokenizer: - # self.tokenizer = CLIPTokenizer.from_pretrained( - # self.args.tokenizer, - # # max_length=self.args.max_text_length, - # # do_lower_case=self.args.do_lower_case - # ) - - self.tokenizer = CLIPTokenizer.from_pretrained( - self.args.tokenizer, - # max_length=self.args.max_text_length, - # do_lower_case=self.args.do_lower_case - ) - - with open(coco_data_dir.joinpath('cocotalk.json')) as f: - self.vocab = list(json.load(f)['ix_to_word'].values()) - popped = self.vocab.pop(-1) - assert popped == 'UNK' - if self.verbose: - print('vocab size: ', len(self.vocab)) - - - data_info_path = coco_data_dir.joinpath('dataset_coco.json') - with open(data_info_path) as f: - karpathy_data = json.load(f) - - split_rename = { - 'train': 'train', - 'restval': 'train', - 'val': 'val', - 'test': 'test' - } - - n_images = 0 - - data = [] - # self.vocab = set() - for datum in karpathy_data['images']: - re_split = split_rename[datum['split']] - - # if re_split == 'train': - # for d in datum['sentences']: - # self.vocab = self.vocab.union(set(d['tokens'])) - - if 
re_split != self.source.split('_')[-1]: - continue - - if re_split == 'train': - # for d in datum['sentences']: - # img_id = datum['filename'].split('.')[0] - # new_datum = { - # 'filename': datum['filename'], - # 'img_id': img_id, - # 'sent': d['raw'].strip(), - # 'targets': [d['raw'].strip() for d in datum['sentences']], - # 'is_train': True, - # 'cocoid': datum['cocoid'] - # } - # data.append(new_datum) - img_id = datum['filename'].split('.')[0] - new_datum = { - 'filename': datum['filename'], - 'img_id': img_id, - # 'sent': d['raw'], - # 'targets': [d['raw'].strip() for d in datum['sentences']], - 'targets': [" ".join(d['tokens']) for d in datum['sentences']], - 'is_train': True, - 'cocoid': datum['cocoid'] - } - data.append(new_datum) - - else: - img_id = datum['filename'].split('.')[0] - new_datum = { - 'filename': datum['filename'], - 'img_id': img_id, - # 'sent': d['raw'], - # 'targets': [d['raw'].strip() for d in datum['sentences']], - 'targets': [" ".join(d['tokens']) for d in datum['sentences']], - 'is_train': False, - 'cocoid': datum['cocoid'] - } - data.append(new_datum) - - n_images += 1 - - if self.verbose: - print(f"{self.source} has {n_images} images") - # print(f"Loaded {len(data)} data from", split) - - self.n_gpus = torch.cuda.device_count() - - if self.topk > 0: - data = data[:self.topk] - if self.verbose: - print(f"Use only {self.topk} data") - - self.data = data - - # if self.verbose: - # print("# all sentences:", len(self.data)) - - if self.args.load_feat: - # feat_dir = coco_dir.joinpath('' - # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False) - self.feat_loader = HybridLoader( - coco_data_dir.joinpath('cocotalk_clipscore_vis'), - ext='.npy', in_memory=False) - else: - if 'openai/clip' in self.args.encoder_backbone: - # from transformers import CLIPProcessor - # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", - # size=args.image_size, - # do_resize=True, - # do_center_crop=False, - # ) - # self.img_transform = lambda image: self.processor.feature_extractor( - # image, - # return_tensors='pt')['pixel_values'][0] - - self.image_mean = [0.48145466, 0.4578275, 0.40821073] - self.image_std = [0.26862954, 0.26130258, 0.27577711] - - # captioning - # self.img_transform = T.Compose([ - # T.Resize((self.args.image_size, self.args.image_size)) - # ]) - - # retrieval - self.img_transform = T.Compose([ - T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC), - T.CenterCrop(self.args.image_size) - ]) - - self.img_tensor_transform = T.Compose([ - # T.RandomCrop(224), - # T.RandomHorizontalFlip(p=0.3), - T.ConvertImageDtype(torch.float), - T.Normalize(self.image_mean, self.image_std) - ] - ) - # elif 'google/vit' in self.args.encoder_backbone: - # self.image_mean = [0.5, 0.5, 0.5] - # self.image_std = [0.5, 0.5, 0.5] - - # self.img_transform = T.Compose([ - # # T.PILToTensor(), - # T.Resize((self.args.image_size, self.args.image_size)) - # ]) - - # self.img_tensor_transform = T.Compose([ - # # T.RandomCrop(224), - # # T.RandomHorizontalFlip(p=0.3), - # T.ConvertImageDtype(torch.float), - # T.Normalize(self.image_mean, self.image_std) - # ] - # ) - - def get_negative_text(self, text): - neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle']) - - if neg_type == 'repeat': - text = text_utils.repeat(text) - elif neg_type == 'remove': - text = text_utils.remove(text) - elif neg_type == 'insert': - text = 
text_utils.insert(text, self.vocab) - elif neg_type == 'swap': - text = text_utils.swap(text, self.vocab) - elif neg_type == 'shuffle': - text = text_utils.shuffle(text) - - return text, neg_type - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - datum = self.data[idx] - return self.process_datum(datum) - - def process_datum(self, datum): - out_dict = {} - - ###### Image ###### - - if self.args.load_feat: - cocoid = datum['cocoid'] - out_dict['cocoid'] = str(cocoid) - img_feat = self.feat_loader.get(str(cocoid)) - out_dict['img_feat'] = torch.from_numpy(img_feat) - - else: - img_id = datum['img_id'] - out_dict['img_id'] = img_id - - if 'train' in datum['filename']: - img_split = 'train2014' - elif 'val' in datum['filename']: - img_split = 'val2014' - img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg') - assert img_path.exists() - img_path = str(img_path) - out_dict['img_path'] = img_path - - img_tensor = torchvision.io.read_image(img_path) - # out_dict['img_tensor'] = img - - # img = Image.open(img_path).convert('RGB') - # img_tensor = torch.as_tensor(np.asarray(img)) - out_dict['img_tensor'] = self.img_transform(img_tensor) - # self.img_transform(img_tensor) - # out_dict['img_tensor'] = self.img_transform(img) - - ###### Text ##### - # if datum['is_train']: - # sent = datum['sent'].strip() - - sent = random.choice(datum['targets']) - - # target_ids = self.tokenizer.encode( - # sent, max_length=self.args.gen_max_length, truncation=True) - - # assert len(target_ids) <= self.args.gen_max_length, len(target_ids) - out_dict['sent'] = sent - # out_dict['target_ids'] = torch.LongTensor(target_ids) - # out_dict['target_length'] = len(target_ids) - - - # negative sample - neg_sent, neg_type = self.get_negative_text(sent) - - # neg_target_ids = self.tokenizer.encode( - # neg_sent, max_length=self.args.gen_max_length, truncation=True) - - # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids) - out_dict['neg_sent'] = neg_sent - out_dict['neg_type'] = neg_type - # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids) - # out_dict['neg_target_length'] = len(neg_target_ids) - - - if 'targets' in datum: - out_dict['targets'] = datum['targets'] - - return out_dict - - def collate_fn(self, batch): - batch_entry = {} - - B = len(batch) - - # if 'target_ids' in batch[0]: - # T_W_L = max(entry['target_length'] for entry in batch) - # target_ids = torch.ones( - # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id - - # if 'target_ids' in batch[0]: - # T_W_L = max(entry['target_length'] for entry in batch) - # target_ids = torch.ones( - # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id - - - - targets = [] - img_ids = [] - img_paths = [] - - coco_ids = [] - - if self.args.load_feat: - img_feats = torch.zeros(B, 512, dtype=torch.float) - else: - # imgs = [] - img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8) - - for i, entry in enumerate(batch): - - if self.args.load_feat: - coco_ids.append(entry['cocoid']) - img_feats[i] = entry['img_feat'] - - else: - - img_ids.append(entry['img_id']) - img_paths.append(entry['img_path']) - img_tensor[i] = entry['img_tensor'] - - # if 'target_ids' in entry: - # target_ids[i, :entry['target_length']] = entry['target_ids'] - - if 'targets' in entry: - targets.append(entry['targets']) - - if 'sent' in batch[0]: - # word_mask = target_ids != self.tokenizer.pad_token_id - # target_ids[~word_mask] = -100 - # 
batch_entry['target_ids'] = target_ids - - tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt') - neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt') - # sent, max_length=self.args.gen_max_length, truncation=True) - - batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask) - batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask) - - - if self.args.load_feat: - batch_entry['coco_ids'] = coco_ids - batch_entry['img_feats'] = img_feats - - else: - - img_tensor = self.img_tensor_transform(img_tensor) - - batch_entry['img_id'] = img_ids - batch_entry['img_paths'] = img_paths - batch_entry['img_tensor'] = img_tensor - - batch_entry['targets'] = targets - - # print('batch created') - - # batch_entry['task'] = 'caption' - - return batch_entry - - -# def get_loader(args, split='karpathy_train', mode='train', -# batch_size=32, workers=4, distributed=False, gpu=0, -# topk=-1): - -# verbose = (gpu == 0) - -# dataset = COCORetrievalDataset( -# split, -# rank=gpu, -# topk=topk, -# verbose=verbose, -# args=args, -# mode=mode) - -# # if distributed: -# # sampler = DistributedSampler(dataset) -# # else: -# # sampler = None - -# if mode == 'train': -# loader = DataLoader( -# dataset, batch_size=batch_size, shuffle=(sampler is None), -# num_workers=workers, pin_memory=True, sampler=sampler, -# collate_fn=dataset.collate_fn) -# else: -# loader = DataLoader( -# dataset, -# batch_size=batch_size, shuffle=False, -# num_workers=workers, pin_memory=True, -# sampler=sampler, -# collate_fn=dataset.collate_fn, -# drop_last=False) - -# # if verbose: -# # loader.evaluator = COCOCaptionEvaluator() - -# # loader.task = 'caption' - -# return loader - - -# class COCOCaptionEvaluator: -# def __init__(self): -# import language_evaluation -# self.evaluator = language_evaluation.CocoEvaluator(verbose=False) - -# def evaluate(self, predicts, answers): - -# results = self.evaluator.run_evaluation(predicts, answers) - -# return results - -import six -import os -import h5py - -class HybridLoader: - """ - If db_path is a director, then use normal file loading - If lmdb, then load from lmdb - The loading method depend on extention. - - in_memory: if in_memory is True, we save all the features in memory - For individual np(y|z)s, we don't need to do that because the system will do this for us. - Should be useful for lmdb or h5. 
- (Copied this idea from vilbert) - """ - - def __init__(self, db_path, ext='.npy', in_memory=False): - self.db_path = db_path - self.ext = ext - if self.ext == '.npy': - self.loader = lambda x: np.load(six.BytesIO(x)) - else: - self.loader = lambda x: np.load(six.BytesIO(x))['feat'] - # if db_path.endswith('.lmdb'): - # self.db_type = 'lmdb' - # self.lmdb = lmdbdict(db_path, unsafe=True) - # self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - # self.lmdb._value_loads = LOADS_FUNC['identity'] - # elif db_path.endswith('.pth'): # Assume a key,value dictionary - # self.db_type = 'pth' - # self.feat_file = torch.load(db_path) - # self.loader = lambda x: x - # print('HybridLoader: ext is ignored') - # elif db_path.endswith('h5'): - # self.db_type = 'h5' - # self.loader = lambda x: np.array(x).astype('float32') - # else: - # self.db_type = 'dir' - - self.in_memory = in_memory - if self.in_memory: - self.features = {} - - def get(self, key): - - # if self.in_memory and key in self.features: - # # We save f_input because we want to save the - # # compressed bytes to save memory - # f_input = self.features[key] - # elif self.db_type == 'lmdb': - # f_input = self.lmdb[key] - # elif self.db_type == 'pth': - # f_input = self.feat_file[key] - # elif self.db_type == 'h5': - # f_input = h5py.File(self.db_path, 'r')[key] - # else: - # f_input = open(os.path.join( - # self.db_path, key + self.ext), 'rb').read() - - f_input = open(os.path.join( - self.db_path, key + self.ext), 'rb').read() - - if self.in_memory and key not in self.features: - self.features[key] = f_input - - # load image - feat = self.loader(f_input) - - return feat diff --git a/retrieval/clip_model.py b/retrieval/clip_model.py deleted file mode 100644 index 83d35620683bd11d3c9e6ac38bf76acbcd364e21..0000000000000000000000000000000000000000 --- a/retrieval/clip_model.py +++ /dev/null @@ -1,350 +0,0 @@ -from transformers import CLIPModel, CLIPTokenizer -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - - -class CLIPScore(nn.Module): - def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False): - super(CLIPScore, self).__init__() - # from transformers import CLIPModel, CLIPTokenizer - self.clip_model = CLIPModel.from_pretrained( - 'openai/clip-vit-base-patch32') - self.tokenizer = CLIPTokenizer.from_pretrained( - 'openai/clip-vit-base-patch32') - - self.clip_model.eval() - - self.clipscore_w = clipscore_w - - self.image_transform = self._transform(image_size) - - self.mode = mode - assert mode in ['clip_s', 'refclip_s'] - - self.use_grammar = use_grammar - self.joint_out = joint_out - - if self.use_grammar and self.joint_out is False: - self.grammar_score_head = nn.Sequential( - nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False), - nn.ReLU(), - nn.Linear(self.clip_model.projection_dim, 2, bias=False) - ) - - def _transform(self, n_px): - return Compose([ - Resize(n_px, interpolation=Image.BICUBIC), - CenterCrop(n_px), - lambda image: image.convert("RGB"), - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)), - ]) - - def load_image(self, image_path): - image = Image.open(image_path) - return 
image - - # @torch.no_grad() - def image_extract(self, image): - if isinstance(image, str): - image = self.load_image(image) - if not isinstance(image, torch.Tensor): - image = self.image_transform(image) - - img_tensor = image.view(-1, 3, 224, 224) - device = next(self.clip_model.parameters()).device - img_tensor = img_tensor.to(device) - - clip_model = self.clip_model - - img_feat = clip_model.vision_model(img_tensor).pooler_output - img_feat = clip_model.visual_projection(img_feat) - img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) - - return img_feat - - # @torch.no_grad() - def text_extract(self, text, prompt="A photo depicts", proj_norm=True): - if isinstance(text, str): - text_batch = [" ".join([prompt, text])] - elif isinstance(text, list): - text_batch = [" ".join([prompt, txt]) for txt in text] - - if isinstance(text, tuple) and isinstance(text[0], torch.Tensor): - input_ids, attention_mask = text - else: - input_text = text_batch - - tokenized = self.tokenizer( - input_text, return_tensors='pt', padding=True) - - input_ids = tokenized.input_ids - attention_mask = tokenized.attention_mask - - clip_model = self.clip_model - device = next(self.clip_model.parameters()).device - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - - text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output - - if proj_norm: - text_feat = clip_model.text_projection(text_feat) - text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) - - return text_feat - - # @torch.no_grad() - def calc_clip_s(self, img_feat, text_feat): - return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) - - # @torch.no_grad() - def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): - - if clip_s is None: - clip_s = self.calc_clip_s(img_feat, text_feat) - - B, dim = img_feat.size() - - ref_text_feat = ref_text_feat.view(B, -1, dim) - - K = ref_text_feat.size(1) - - text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) - assert ref_text_feat.size() == text_feat.size( - ), (ref_text_feat.size(), text_feat.size()) - - ref_score = self.calc_clip_s(text_feat, ref_text_feat) - if ref_text_mask is not None: - if not isinstance(ref_text_mask, torch.Tensor): - ref_text_mask = torch.tensor( - ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) - ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) - - ref_score = ref_score.view(B, K).max(dim=1).values - - assert clip_s.size() == (B,) - assert clip_s.size() == ref_score.size() - - # harmonic mean - refclip_s = 2 / (1 / clip_s + 1 / ref_score) - return refclip_s - - # # @torch.no_grad() - # def forward(self, - # images=None, text=None, - # img_feat=None, text_feat=None, - # ref_text=None, ref_text_feat=None, ref_text_mask=None, - # prompt="A photo depicts", - # mode=None): - # if img_feat is None: - # img_feat = self.image_extract(images) - # img_feat = img_feat.view(-1, 512) - - # if text_feat is None: - # text_feat = self.text_extract(text, prompt=prompt) - # text_feat = text_feat.view(-1, 512) - - # if mode is None: - # mode = self.mode - # assert mode in ['clip_s', 'refclip_s'] - - # if mode == 'clip_s': - # clip_s = self.calc_clip_s(img_feat, text_feat) - # return clip_s - # elif mode == 'refclip_s': - # if ref_text_feat is None: - # ref_text_feat = self.text_extract(ref_text, prompt=prompt) - # ref_text_feat = ref_text_feat.view(-1, 512) - - # refclip_s = self.calc_refclip_s( - # img_feat, text_feat, ref_text_feat, 
ref_text_mask=ref_text_mask) - # return refclip_s - - - def train_step(self, - images=None, text=None, - img_feat=None, text_feat=None, - neg_text=None, neg_text_feat=None, - # ref_text=None, ref_text_feat=None, ref_text_mask=None, - prompt="A photo depicts", - # return_loss=True, - **kwargs): - - if img_feat is None: - img_feat = self.image_extract(images) - img_feat = img_feat.view(-1, 512) - - B = img_feat.size(0) - - if self.joint_out: - pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512) - neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512) - neg_B = neg_text_feat.size(0) - - # [B+neg_B, 512] - text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) - - text_cont_feat = self.clip_model.text_projection(text_feat) - text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) - - text_cont_feat = text_cont_feat.view(B+neg_B, 512) - - logit_scale = self.clip_model.logit_scale.exp() - - # [B+neg_B * B] - logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale - - # image-to-text label: positive text - caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean() - - # calculate text-to-image only on positive text - image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean() - - clip_loss = (caption_loss + image_loss) / 2.0 - - out = { - 'clip_loss': clip_loss, - 'img_feat': img_feat, - 'text_feat': text_cont_feat[:B].detach(), - # 'neg_text_feat': neg_text_feat, - } - - return out - - - else: - if text_feat is None: - text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) - - text_cont_feat = self.clip_model.text_projection(text_feat) - text_cont_feat = text_cont_feat / \ - text_cont_feat.norm(dim=-1, keepdim=True) - - text_cont_feat = text_cont_feat.view(B, 512) - - - # cosine similarity as logits - logit_scale = self.clip_model.logit_scale.exp() - logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale - # logits_per_image = logits_per_text.T - - clip_loss = clip_loss_fn(logits_per_text) - - - # negative sampling - pos_text_feat = text_feat.view(B, 512) - neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) - - grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) - - # 2B, 1 - grammar_text_logit = self.grammar_score_head(grammar_text_feat) - grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) - - grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) - - grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) - grammar_pos_pred = grammar_pred[:B] - grammar_neg_pred = grammar_pred[B:] - # grammar_acc = (grammar_pred == grammar_labels).float().mean() - - out = { - 'clip_loss': clip_loss, - 'grammar_loss': grammar_loss, - 'img_feat': img_feat, - 'text_feat': text_cont_feat, - 'neg_text_feat': neg_text_feat, - 'grammar_pos_pred': grammar_pos_pred, - 'grammar_neg_pred': grammar_neg_pred, - } - - return out - - def train_step_old(self, - images=None, text=None, - img_feat=None, text_feat=None, - neg_text=None, neg_text_feat=None, - # ref_text=None, ref_text_feat=None, ref_text_mask=None, - prompt="A photo depicts", - # return_loss=True, - **kwargs): - - if img_feat is None: - img_feat = self.image_extract(images) - img_feat = img_feat.view(-1, 512) - - B = img_feat.size(0) - - - - if text_feat is None: - text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) 
- - text_cont_feat = self.clip_model.text_projection(text_feat) - text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) - text_cont_feat = text_cont_feat.view(B, 512) - - # cosine similarity as logits - logit_scale = self.clip_model.logit_scale.exp() - logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale - # logits_per_image = logits_per_text.T - - clip_loss = clip_loss_fn(logits_per_text) - - - # negative sampling - pos_text_feat = text_feat.view(B, 512) - neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) - - grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) - - # 2B, 1 - grammar_text_logit = self.grammar_score_head(grammar_text_feat) - grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) - - grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) - - grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) - grammar_pos_pred = grammar_pred[:B] - grammar_neg_pred = grammar_pred[B:] - # grammar_acc = (grammar_pred == grammar_labels).float().mean() - - out = { - 'clip_loss': clip_loss, - 'grammar_loss': grammar_loss, - 'img_feat': img_feat, - 'text_feat': text_cont_feat, - 'neg_text_feat': neg_text_feat, - 'grammar_pos_pred': grammar_pos_pred, - 'grammar_neg_pred': grammar_neg_pred, - } - - return out - -# contrastive loss function, adapted from -# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html -def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor: - neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) - return -neg_ce.mean() - - -def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor: - caption_loss = contrastive_loss(similarity, dim=0) - image_loss = contrastive_loss(similarity, dim=1) - return (caption_loss + image_loss) / 2.0 diff --git a/retrieval/configs/clip_negative_text.yaml b/retrieval/configs/clip_negative_text.yaml deleted file mode 100644 index 21dbeea2221a183bd4599a2f85ad039afc5ff44f..0000000000000000000000000000000000000000 --- a/retrieval/configs/clip_negative_text.yaml +++ /dev/null @@ -1,14 +0,0 @@ -checkpoint_dir: ./save/clip_negative_text/ - -losses_log_every: 25 -precision: 32 -load_feat: true -data_in_memory: false - -batch_size: 1600 -valid_batch_size: 200 -clip_grad_norm: 0 - -epochs: 30 -use_grammar: true -joint_out: false \ No newline at end of file diff --git a/retrieval/param.py b/retrieval/param.py deleted file mode 100644 index 45feaa691759a7f0a04080cd397764e6e5362a36..0000000000000000000000000000000000000000 --- a/retrieval/param.py +++ /dev/null @@ -1,209 +0,0 @@ -import argparse -import random - -import numpy as np -import torch - -import pprint -import yaml - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def is_interactive(): - import __main__ as main - return not hasattr(main, '__file__') - - -def get_optimizer(optim, verbose=False): - # Bind the optimizer - if optim == 'rms': - if verbose: - print("Optimizer: Using RMSProp") - optimizer = torch.optim.RMSprop - elif optim == 'adam': - if verbose: - print("Optimizer: Using Adam") - optimizer = torch.optim.Adam - elif optim == 'adamw': - if verbose: - print("Optimizer: Using AdamW") - # optimizer = torch.optim.AdamW - optimizer = 'adamw' - elif optim == 'adamax': - if 
verbose: - print("Optimizer: Using Adamax") - optimizer = torch.optim.Adamax - elif optim == 'sgd': - if verbose: - print("Optimizer: SGD") - optimizer = torch.optim.SGD - else: - assert False, "Please add your optimizer %s in the list." % optim - - return optimizer - - -def parse_args(parse=True, **optional_kwargs): - parser = argparse.ArgumentParser() - - parser.add_argument('--seed', type=int, default=9595, help='random seed') - - # Data Splits - parser.add_argument("--train", default='karpathy_train') - parser.add_argument("--valid", default='karpathy_val') - parser.add_argument("--test", default='karpathy_test') - # parser.add_argument('--test_only', action='store_true') - - # Quick experiments - parser.add_argument('--train_topk', type=int, default=-1) - parser.add_argument('--valid_topk', type=int, default=-1) - - # Checkpoint - parser.add_argument('--output', type=str, default='snap/test') - parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).') - parser.add_argument('--from_scratch', action='store_true') - - # CPU/GPU - parser.add_argument("--multiGPU", action='store_const', default=False, const=True) - parser.add_argument('--fp16', action='store_true') - parser.add_argument("--distributed", action='store_true') - parser.add_argument("--num_workers", default=0, type=int) - parser.add_argument('--local_rank', type=int, default=-1) - # parser.add_argument('--rank', type=int, default=-1) - - # Model Config - # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32') - # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased') - parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32') - - # parser.add_argument('--position_embedding_type', type=str, default='absolute') - - # parser.add_argument('--encoder_transform', action='store_true') - - parser.add_argument('--max_text_length', type=int, default=40) - - # parser.add_argument('--image_size', type=int, default=224) - # parser.add_argument('--patch_size', type=int, default=32) - - # parser.add_argument('--decoder_num_layers', type=int, default=12) - - # Training - parser.add_argument('--batch_size', type=int, default=256) - parser.add_argument('--valid_batch_size', type=int, default=None) - - parser.add_argument('--optim', default='adamw') - - parser.add_argument('--warmup_ratio', type=float, default=0.05) - parser.add_argument('--weight_decay', type=float, default=0.01) - parser.add_argument('--clip_grad_norm', type=float, default=-1.0) - parser.add_argument('--gradient_accumulation_steps', type=int, default=1) - parser.add_argument('--lr', type=float, default=1e-4) - parser.add_argument('--adam_eps', type=float, default=1e-6) - parser.add_argument('--adam_beta1', type=float, default=0.9) - parser.add_argument('--adam_beta2', type=float, default=0.999) - - parser.add_argument('--epochs', type=int, default=20) - # parser.add_argument('--dropout', type=float, default=0.1) - - - # Inference - # parser.add_argument('--num_beams', type=int, default=1) - # parser.add_argument('--gen_max_length', type=int, default=20) - - parser.add_argument('--start_from', type=str, default=None) - - # Data - # parser.add_argument('--do_lower_case', type=str2bool, default=None) - - # parser.add_argument('--prefix', type=str, default=None) - - - # COCO Caption - # parser.add_argument('--no_prefix', action='store_true') - - parser.add_argument('--no_cls', action='store_true') - - parser.add_argument('--cfg', type=str, 
default=None) - parser.add_argument('--id', type=str, default=None) - - # Etc. - parser.add_argument('--comment', type=str, default='') - parser.add_argument("--dry", action='store_true') - - # Parse the arguments. - if parse: - args = parser.parse_args() - # For interative engironmnet (ex. jupyter) - else: - args = parser.parse_known_args()[0] - - loaded_kwargs = {} - if args.cfg is not None: - cfg_path = f'configs/{args.cfg}.yaml' - with open(cfg_path, 'r') as f: - loaded_kwargs = yaml.safe_load(f) - - # Namespace => Dictionary - parsed_kwargs = vars(args) - parsed_kwargs.update(optional_kwargs) - - kwargs = {} - kwargs.update(parsed_kwargs) - kwargs.update(loaded_kwargs) - - args = Config(**kwargs) - - # Bind optimizer class. - verbose = False - args.optimizer = get_optimizer(args.optim, verbose=verbose) - - # Set seeds - torch.manual_seed(args.seed) - random.seed(args.seed) - np.random.seed(args.seed) - - return args - - -class Config(object): - def __init__(self, **kwargs): - """Configuration Class: set kwargs as class attributes with setattr""" - for k, v in kwargs.items(): - setattr(self, k, v) - - @property - def config_str(self): - return pprint.pformat(self.__dict__) - - def __repr__(self): - """Pretty-print configurations in alphabetical order""" - config_str = 'Configurations\n' - config_str += self.config_str - return config_str - - # def update(self, **kwargs): - # for k, v in kwargs.items(): - # setattr(self, k, v) - - # def save(self, path): - # with open(path, 'w') as f: - # yaml.dump(self.__dict__, f, default_flow_style=False) - - # @classmethod - # def load(cls, path): - # with open(path, 'r') as f: - # kwargs = yaml.load(f) - - # return Config(**kwargs) - - -if __name__ == '__main__': - args = parse_args(True) diff --git a/retrieval/pth_loader.py b/retrieval/pth_loader.py deleted file mode 100644 index 388301edd763d54d95675ca2ed6eb502f77e1eb1..0000000000000000000000000000000000000000 --- a/retrieval/pth_loader.py +++ /dev/null @@ -1,334 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import h5py -from lmdbdict import lmdbdict -from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC -import os -import numpy as np -import numpy.random as npr -import random - -import torch -import torch.utils.data as data - -import multiprocessing -import six - -verbose = True -# import torch -# if torch.cuda.current_device() in [0, -1]: -if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - verbose = False - -class HybridLoader: - """ - If db_path is a director, then use normal file loading - If lmdb, then load from lmdb - The loading method depend on extention. - - in_memory: if in_memory is True, we save all the features in memory - For individual np(y|z)s, we don't need to do that because the system will do this for us. - Should be useful for lmdb or h5. 
- (Copied this idea from vilbert) - """ - def __init__(self, db_path, ext, in_memory=False): - self.db_path = db_path - self.ext = ext - if self.ext == '.npy': - self.loader = lambda x: np.load(six.BytesIO(x)) - else: - self.loader = lambda x: np.load(six.BytesIO(x))['feat'] - if db_path.endswith('.lmdb'): - self.db_type = 'lmdb' - self.lmdb = lmdbdict(db_path, unsafe=True) - self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - self.lmdb._value_loads = LOADS_FUNC['identity'] - elif db_path.endswith('.pth'): # Assume a key,value dictionary - self.db_type = 'pth' - self.feat_file = torch.load(db_path) - self.loader = lambda x: x - print('HybridLoader: ext is ignored') - elif db_path.endswith('h5'): - self.db_type = 'h5' - self.loader = lambda x: np.array(x).astype('float32') - else: - self.db_type = 'dir' - - self.in_memory = in_memory - if self.in_memory: - self.features = {} - - def get(self, key): - - if self.in_memory and key in self.features: - # We save f_input because we want to save the - # compressed bytes to save memory - f_input = self.features[key] - elif self.db_type == 'lmdb': - f_input = self.lmdb[key] - elif self.db_type == 'pth': - f_input = self.feat_file[key] - elif self.db_type == 'h5': - f_input = h5py.File(self.db_path, 'r')[key] - else: - f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() - - if self.in_memory and key not in self.features: - self.features[key] = f_input - - # load image - feat = self.loader(f_input) - - return feat - -class CaptionDataset(data.Dataset): - - def get_vocab_size(self): - return self.vocab_size - - def get_vocab(self): - return self.ix_to_word - - def get_seq_length(self): - return self.seq_length - - def __init__(self, opt): - self.opt = opt - self.seq_per_img = opt.seq_per_img - - # feature related options - self.use_fc = getattr(opt, 'use_fc', True) - self.use_att = getattr(opt, 'use_att', True) - self.use_box = getattr(opt, 'use_box', 0) - self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) - self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) - - # load the json file which contains additional information about the dataset - if verbose: - print('DataLoader loading json file: ', opt.input_json) - self.info = json.load(open(self.opt.input_json)) - if 'ix_to_word' in self.info: - self.ix_to_word = self.info['ix_to_word'] - self.vocab_size = len(self.ix_to_word) - if verbose: - print('vocab size is ', self.vocab_size) - - # open the hdf5 file - if verbose: - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) - """ - Setting input_label_h5 to none is used when only doing generation. - For example, when you need to test on coco test set. 
- """ - if self.opt.input_label_h5 != 'none': - self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') - # load in the sequence data - seq_size = self.h5_label_file['labels'].shape - self.label = self.h5_label_file['labels'][:] - self.seq_length = seq_size[1] - if verbose: - print('max sequence length in data is', self.seq_length) - # load the pointers in full to RAM (should be small enough) - self.label_start_ix = self.h5_label_file['label_start_ix'][:] - self.label_end_ix = self.h5_label_file['label_end_ix'][:] - else: - self.seq_length = 1 - - self.data_in_memory = getattr(opt, 'data_in_memory', False) - self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) - self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) - self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) - - self.use_clipscore = getattr(opt, 'use_clipscore', False) - if self.use_clipscore: - self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) - - - self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] - if verbose: - print('read %d image features' %(self.num_images)) - - # separate out indexes for each of the provided splits - self.split_ix = {'train': [], 'val': [], 'test': []} - for ix in range(len(self.info['images'])): - img = self.info['images'][ix] - if not 'split' in img: - self.split_ix['train'].append(ix) - self.split_ix['val'].append(ix) - self.split_ix['test'].append(ix) - elif img['split'] == 'train': - self.split_ix['train'].append(ix) - elif img['split'] == 'val': - self.split_ix['val'].append(ix) - elif img['split'] == 'test': - self.split_ix['test'].append(ix) - elif opt.train_only == 0: # restval - self.split_ix['train'].append(ix) - - if verbose: - print('assigned %d images to split train' %len(self.split_ix['train'])) - print('assigned %d images to split val' %len(self.split_ix['val'])) - print('assigned %d images to split test' %len(self.split_ix['test'])) - - def get_captions(self, ix, seq_per_img): - # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 - ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' - - if ncap < seq_per_img: - # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') - for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.label[ixl, :self.seq_length] - else: - ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] - - return seq - - def collate_func(self, batch): - seq_per_img = self.seq_per_img - - fc_batch = [] - att_batch = [] - label_batch = [] - - clip_vis_feat_batch = [] - - wrapped = False - - infos = [] - gts = [] - - for sample in batch: - # fetch image - if self.use_clipscore: - tmp_fc, tmp_att, tmp_seq, \ - ix, tmp_clip_vis_feat = sample - - clip_vis_feat_batch.append(tmp_clip_vis_feat) - else: - tmp_fc, tmp_att, tmp_seq, \ - ix = sample - - fc_batch.append(tmp_fc) - att_batch.append(tmp_att) - - tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') - if hasattr(self, 'h5_label_file'): - # if there is ground truth - tmp_label[:, 1 : self.seq_length + 1] = tmp_seq - label_batch.append(tmp_label) - - # Used for reward evaluation - if hasattr(self, 'h5_label_file'): - # if there is ground truth - gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) - else: - gts.append([]) - - # record associated info as well - info_dict = {} - info_dict['ix'] = ix - info_dict['id'] = self.info['images'][ix]['id'] - info_dict['file_path'] = self.info['images'][ix].get('file_path', '') - infos.append(info_dict) - - # #sort by att_feat length - # fc_batch, att_batch, label_batch, gts, infos = \ - # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - if self.use_clipscore: - fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True)) - else: - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) - data = {} - data['fc_feats'] = np.stack(fc_batch) - # merge att_feats - max_att_len = max([_.shape[0] for _ in att_batch]) - data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') - for i in range(len(att_batch)): - data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] - data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') - for i in range(len(att_batch)): - data['att_masks'][i, :att_batch[i].shape[0]] = 1 - # set att_masks to None if attention features have same length - if data['att_masks'].sum() == data['att_masks'].size: - data['att_masks'] = None - - if self.use_clipscore: - data['clip_vis_feats'] = np.stack(clip_vis_feat_batch) - - data['labels'] = np.vstack(label_batch) - # generate mask - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) - mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') - for ix, row in enumerate(mask_batch): - row[:nonzeros[ix]] = 1 - data['masks'] = mask_batch - data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) - data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) - - data['gts'] = gts # all ground truth captions of each images - data['infos'] = infos - - data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor - - return 
data - - def __getitem__(self, ix): - """This function returns a tuple that is further passed to collate_fn - """ - if self.use_att: - att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) - # Reshape to K x C - att_feat = att_feat.reshape(-1, att_feat.shape[-1]) - if self.norm_att_feat: - att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) - if self.use_box: - box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) - # devided by image width and height - x1,y1,x2,y2 = np.hsplit(box_feat, 4) - h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] - box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? - if self.norm_box_feat: - box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) - att_feat = np.hstack([att_feat, box_feat]) - # sort the features by the size of boxes - att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) - else: - att_feat = np.zeros((0,0), dtype='float32') - if self.use_fc: - try: - fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) - except: - # Use average of attention when there is no fc provided (For bottomup feature) - fc_feat = att_feat.mean(0) - else: - fc_feat = np.zeros((0), dtype='float32') - if hasattr(self, 'h5_label_file'): - seq = self.get_captions(ix, self.seq_per_img) - else: - seq = None - - if self.use_clipscore: - clip_vis_feat = self.clipscore_loader.get( - str(self.info['images'][ix]['id'])) - - return (fc_feat, - att_feat, seq, - ix, clip_vis_feat) - - return (fc_feat, - att_feat, seq, - ix) - - def __len__(self): - return len(self.info['images']) diff --git a/retrieval/text_utils.py b/retrieval/text_utils.py deleted file mode 100644 index 51f981054b41a945656f9e619c722e09de198bf7..0000000000000000000000000000000000000000 --- a/retrieval/text_utils.py +++ /dev/null @@ -1,74 +0,0 @@ -import random - -def repeat(text, n_max_gram=3, n_max_repeat=3): - """repeat n-grams""" - tokens = text.split() - - n_gram = random.randint(1, n_max_gram) - - repeat_token_idx = random.randint(0, len(tokens) - n_gram) - - repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram] - - n_repeat = random.randint(1, n_max_repeat) - for _ in range(n_repeat): - insert_idx = random.randint(0, len(tokens)) - tokens = tokens[:insert_idx] + \ - repeated_tokens + tokens[insert_idx:] - - new_text = " ".join(tokens) - return new_text - -def remove(text, n_max_gram=3): - """remove n-grams""" - tokens = text.split() - - n_gram = random.randint(1, n_max_gram) - - remove_token_idx = random.randint(0, len(tokens) - n_gram) - - tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:] - - new_text = " ".join(tokens) - return new_text - -def insert(text, vocab, n_max_tokens=3): - """Insert tokens""" - tokens = text.split() - - n_insert_token = random.randint(1, n_max_tokens) - - for _ in range(n_insert_token): - insert_token_idx = random.randint(0, len(tokens) - 1) - insert_token = random.choice(vocab) - tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:] - - new_text = " ".join(tokens) - return new_text - -def swap(text, vocab, n_max_tokens=3): - """Swap tokens""" - tokens = text.split() - - n_swap_tokens = random.randint(1, n_max_tokens) - - for _ in range(n_swap_tokens): - swap_token_idx = random.randint(0, len(tokens) - 1) - - swap_token = random.choice(vocab) - while swap_token == tokens[swap_token_idx]: - swap_token = random.choice(vocab) - - tokens[swap_token_idx] = swap_token - - new_text = " 
".join(tokens) - return new_text - -def shuffle(text): - """shuffle tokens""" - tokens = text.split() - - random.shuffle(tokens) - - new_text = " ".join(tokens) - return new_text diff --git a/retrieval/train_pl.py b/retrieval/train_pl.py deleted file mode 100644 index 28f1330c945dd4b083a0adff287e4020b2433a4d..0000000000000000000000000000000000000000 --- a/retrieval/train_pl.py +++ /dev/null @@ -1,661 +0,0 @@ -from ast import parse -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim - -import numpy as np - -import time -import os -from collections import defaultdict - -# import captioning.utils.opts as opts -# import captioning.models as models -# from captioning.data.pth_loader import CaptionDataset -# import captioning.utils.eval_utils as eval_utils -# import captioning.utils.misc as utils -# from captioning.utils.rewards import init_scorer, get_self_critical_reward -# from captioning.modules.loss_wrapper import LossWrapper - -from clip_model import CLIPScore -from caption_data import COCORetrievalDataset - -import pytorch_lightning as pl - -import detectron2.utils.comm as d2comm -from detectron2.utils.env import seed_all_rng -seed_all_rng(1234) - - -class LitModel(pl.LightningModule): - def __init__(self, opt): - super().__init__() - self.opt = opt - self.args = args - # Intilaize dataset - # self.dataset = CaptionDataset(opt) - - # self.dataset = - - # opt.vocab_size = self.dataset.vocab_size - # opt.seq_length = self.dataset.seq_length - # self.batch_size = opt.batch_size - - # Build model - # opt.vocab = self.dataset.get_vocab() - # model = models.setup(opt) - # print(model) - # del opt.vocab - - # wrapper with loss in it. - # lw_model = LossWrapper(model, opt) - - self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out) - # self.lw_model = lw_model - - for p in self.model.clip_model.vision_model.parameters(): - p.requires_grad = False - for p in self.model.clip_model.visual_projection.parameters(): - p.requires_grad = False - - # self.struc_flag = None - # self.sc_flag = None - - - def forward(self, *args, **kwargs): - """ - I hate this design. 
Never pretend it as a nn.Module - """ - raise NotImplementedError - - def train_dataloader(self): - # train_dataset = torch.utils.data.Subset( - # self.dataset, - # self.dataset.split_ix['train'] - # ) - - # train_loader = torch.utils.data.DataLoader( - # dataset=train_dataset, - # batch_size=self.batch_size, - # shuffle=True, - # num_workers=4, - # collate_fn=self.dataset.collate_func - # ) - - train_dataset = COCORetrievalDataset( - split='karpathy_train', mode='train', - args=opt, - verbose=verbose - ) - - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=opt.batch_size, - shuffle=True, - num_workers=4, - collate_fn=train_dataset.collate_fn - ) - - return train_loader - - def val_dataloader(self, split='karpathy_val'): - # val_dataset = torch.utils.data.Subset( - # self.dataset, - # self.dataset.split_ix[split] - # ) - # val_loader = torch.utils.data.DataLoader( - # val_dataset, - # batch_size=self.batch_size, - # shuffle=False, - # num_workers=4, - # drop_last=False, - # collate_fn=self.dataset.collate_func - # ) - - val_dataset = COCORetrievalDataset( - split=split, mode='val', - args=opt, - verbose=verbose - ) - - val_loader = torch.utils.data.DataLoader( - dataset=val_dataset, - batch_size=opt.valid_batch_size, - shuffle=False, - num_workers=4, - drop_last=False, - collate_fn=val_dataset.collate_fn - ) - - return val_loader - - def test_dataloader(self): - - return self.val_dataloader('karpathy_test') - - def training_step(self, data, batch_idx): - - - batch = data - self.model.train() - - model_out = self.model.train_step( - img_feat=batch['img_feats'], - text=batch['text'], - neg_text=batch['neg_text'], - ) - - clip_loss = model_out['clip_loss'] - - if self.opt.joint_out: - loss = clip_loss - else: - grammar_loss = model_out['grammar_loss'] - loss = clip_loss + grammar_loss - - - data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1] - data_time = torch.tensor(data_time) - - # print('batch_idx', batch_idx) - # print('loss:', loss) - - # logger_logs = model_out.copy() - logger_logs = {} - - logger_logs['loss'] = loss.detach() - - logger_logs['clip_loss'] = clip_loss.detach() - - if not self.opt.joint_out: - logger_logs['grammar_loss'] = grammar_loss.detach() - - logger_logs['data_time'] = data_time.detach() - - # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 - # Please use self.log(...) inside the lightningModule instead. - - # # log on a step or aggregate epoch metric to the logger and/or progress bar - # # (inside LightningModule) - # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) - # warnings.warn(*args, **kwargs) - # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 - # Please use self.log(...) inside the lightningModule instead. 
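The neg_text entries consumed by train_step above are presumably produced by the n-gram corruption helpers in retrieval/text_utils.py, removed earlier in this diff. A hedged sketch of how one corrupted caption could be generated with those helpers; the import path, vocabulary, and caption below are placeholders:

import random
# repeat/remove/insert/swap/shuffle as defined in retrieval/text_utils.py (deleted above);
# the flat import assumes that file is on the Python path.
from text_utils import repeat, remove, insert, swap, shuffle

vocab = ["dog", "ball", "grass", "red", "running"]      # placeholder token vocabulary
caption = "a dog is running on the grass"

corruption = random.choice([
    lambda t: repeat(t),            # duplicate a random n-gram
    lambda t: remove(t),            # drop a random n-gram
    lambda t: insert(t, vocab),     # splice in random tokens
    lambda t: swap(t, vocab),       # replace tokens with random ones
    lambda t: shuffle(t),           # permute all tokens
])
neg_caption = corruption(caption)   # e.g. "a dog is is running on the grass"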
- - # output = { - # 'loss': loss, - # 'log': logger_logs, - # 'progress_bar': {'data_time': data_time} - # } - - for k, v in logger_logs.items(): - if k in ['data_time', 'clip_loss', 'grammar_loss']: - self.log('train/'+k, v, prog_bar=True) - else: - self.log('train/'+k, v) - - # print('training step logged') - - return loss - - def validation_step(self, data, batch_idx): - - batch = data - self.model.eval() - - with torch.no_grad(): - model_out = self.model.train_step( - img_feat=batch['img_feats'], - text=batch['text'], - neg_text=batch['neg_text'], - ) - - if self.opt.joint_out: - clip_loss = model_out['clip_loss'] - loss = clip_loss - - output = { - # 'val_loss': loss, - 'loss': loss.detach(), - 'clip_loss': clip_loss.detach(), - # 'grammar_loss': grammar_loss.detach(), - - 'img_feat': model_out['img_feat'].detach(), - 'text_feat': model_out['text_feat'].detach(), - # 'neg_text_feat': model_out['neg_text_feat'].detach(), - # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(), - # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(), - # 'predictions': predictions, - # 'n_predictions': n_predictions, - } - else: - clip_loss = model_out['clip_loss'] - grammar_loss = model_out['grammar_loss'] - loss = clip_loss + grammar_loss - - output = { - # 'val_loss': loss, - 'loss': loss.detach(), - 'clip_loss': clip_loss.detach(), - 'grammar_loss': grammar_loss.detach(), - - 'img_feat': model_out['img_feat'].detach(), - 'text_feat': model_out['text_feat'].detach(), - # 'neg_text_feat': model_out['neg_text_feat'].detach(), - 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(), - 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(), - # 'predictions': predictions, - # 'n_predictions': n_predictions, - } - return output - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - def validation_epoch_end(self, outputs, split='val'): - outputs = d2comm.gather(outputs) - # master node - if d2comm.is_main_process(): - assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0 - outputs = sum(outputs, []) - - out = {} - - val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs) - val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs) - if not self.opt.joint_out: - val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs) - - print('loss', val_loss_mean.item()) - print('clip_loss', val_clip_loss_mean.item()) - if not self.opt.joint_out: - print('grammar_loss', val_grammar_loss_mean.item()) - - logit_scale = self.model.clip_model.logit_scale.exp().cpu() - - text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0) - img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0) - - assert text_feats.size() == (5000, 512), text_feats.size() - assert img_feats.size() == (5000, 512), img_feats.size() - - logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale - logits_per_image = logits_per_text.T - - # text-to-image retrieval - print('Text-to-Image retrieval') - for k in [1, 5, 10]: - text_to_image_topk = logits_per_text.topk(k, dim=1).indices - - n_text = len(text_to_image_topk) - - labels = torch.arange(0, n_text).view(-1, 1) - - n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum() - - recall_k = n_retrieved / n_text * 100 - - out[f'text_to_image_recall_{k}'] = recall_k.item() - - print(f'R@{k}: {recall_k.item():.2f}%') - - # image-to-text retrieval - print('Image-to-Text retrieval') - for k in [1, 5, 10]: - 
image_to_text_topk = logits_per_image.topk(k, dim=1).indices - - n_image = len(image_to_text_topk) - - labels = torch.arange(0, n_image).view(-1, 1) - - n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum() - - recall_k = n_retrieved / n_image * 100 - - out[f'image_to_text_recall_{k}'] = recall_k.item() - - print(f'R@{k}: {recall_k.item():.2f}%') - - out.update({ - 'loss': val_loss_mean.item(), - 'clip_loss': val_clip_loss_mean.item() - }) - - if not self.opt.joint_out: - # grammar scoring - grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0) - grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0) - - TP = (grammar_pos_pred == 1).sum().item() - FP = (grammar_pos_pred == 0).sum().item() - FN = (grammar_neg_pred == 1).sum().item() - TN = (grammar_neg_pred == 0).sum().item() - print('Grammar check') - print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}') - - precision = TP / (TP + FP) * 100 - recall = TP / (TP + FN) * 100 - accuracy = (TP + TN) / (TP + FP + FN + TN) * 100 - f1 = 2 * precision * recall / (precision + recall) - print(f'Precision: {precision:.2f}%') - print(f'Recall: {recall:.2f}%') - print(f'Accuracy: {accuracy:.2f}%') - print(f'F1: {f1:.2f}%') - print('Total: {}'.format(len(grammar_pos_pred))) - - out.update({ - 'grammar_loss': val_grammar_loss_mean, - - 'grammar_precision': precision, - 'grammar_recall': recall, - 'grammar_accuracy': accuracy, - 'grammar_f1': f1, - - }) - - else: - out = {} - - out = d2comm.all_gather(out)[0] # Only the one from master node - assert len(out) > 0 # make sure the head has index 0 - - # must all be tensors - out = {k: torch.tensor(v) if not torch.is_tensor( - v) else v for k, v in out.items()} - - for k, v in out.items(): - self.log(f'{split}/{k}', v) - - def test_epoch_end(self, outputs): - - self.validation_epoch_end(outputs, 'test') - - def configure_optimizers(self): - # opt = self.opt - # model = self.model - - # parameters = [p for p in model.parameters() if p.requires_grad] - - # if opt.noamopt: - # # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer' - # optimizer = utils.get_std_opt( - # model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) - # elif opt.reduce_on_plateau: - # # optimizer = utils.build_optimizer(model.parameters(), opt) - # optimizer = utils.build_optimizer(parameters, opt) - # optimizer = utils.ReduceLROnPlateau(optimizer, - # factor=opt.reduce_on_plateau_factor, - # patience=opt.reduce_on_plateau_patience) - # else: - # # optimizer = utils.build_optimizer(model.parameters(), opt) - # optimizer = utils.build_optimizer(parameters, opt) - - - # from transformers.optimization import AdamW, get_linear_schedule_with_warmup - # batch_per_epoch = len(self.train_loader) - # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs - # warmup_ratio = self.args.warmup_ratio - # warmup_iters = int(t_total * warmup_ratio) - # if self.verbose: - # print("Batch per epoch: %d" % batch_per_epoch) - # print("Total Iters: %d" % t_total) - # print('Warmup ratio:', warmup_ratio) - # print("Warm up Iters: %d" % warmup_iters) - - if self.args.optim == 'adamw': - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": self.args.weight_decay, - }, - { - "params": [p for n, p in self.model.named_parameters() if any(nd in n for 
nd in no_decay)], - "weight_decay": 0.0, - }, - ] - - for group in optimizer_grouped_parameters: - group['params'] = [p for p in group['params'] if p.requires_grad] - - from transformers.optimization import AdamW - optim = AdamW(optimizer_grouped_parameters, - lr=self.args.lr, eps=self.args.adam_eps) - # lr_scheduler = get_linear_schedule_with_warmup( - # optim, warmup_iters, t_total) - - # optimizers = [] - optimizers = [optim] - lr_schedulers = [] - - return optimizers, lr_schedulers - - def optimizer_step(self, epoch, batch_idx, optimizer, - optimizer_idx, *args, **kwargs): - # # warm up lr - # opt = self.opt - # iteration = self.trainer.global_step - # if opt.use_warmup and (iteration < opt.noamopt_warmup): - # opt.current_lr = opt.learning_rate * \ - # (iteration+1) / opt.noamopt_warmup - # utils.set_lr(optimizer, opt.current_lr) - - super().optimizer_step(epoch, batch_idx, optimizer, - optimizer_idx, *args, **kwargs) - - # print('optimizer step') - - def state_dict(self): - """ - Save the model state dict as well as opt and vocab - """ - state_dict = self.model.state_dict() - device = next(iter(state_dict.values())).device - assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case' - # state_dict.update({ - # '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device), - # '_opt': utils.serialize_to_tensor(self.opt).to(device) - # }) - return state_dict - - def load_state_dict(self, state_dict=None, strict=True): - # if '_vocab' in state_dict: - # self.model.vocab = utils.deserialize(state_dict['_vocab']) - # del state_dict['_vocab'] - # elif strict: - # raise KeyError - # if '_opt' in state_dict: - # saved_model_opt = utils.deserialize(state_dict['_opt']) - # del state_dict['_opt'] - # opt = self.opt - # # Make sure the saved opt is compatible with the curren topt - # need_be_same = ["caption_model", - # "rnn_type", "rnn_size", "num_layers"] - # for checkme in need_be_same: - # if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \ - # getattr(opt, checkme) in ['updown', 'topdown']: - # continue - # assert getattr(saved_model_opt, checkme) == getattr( - # opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme - # elif strict: - # raise KeyError - self.model.load_state_dict(state_dict, strict) - - -class OnEpochStartCallback(pl.Callback): - - def on_epoch_start(self, trainer, pl_module): - # Update lr/training stage/scheduled sampling prob etc. 
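For reference, the text-to-image and image-to-text R@k numbers printed by validation_epoch_end above reduce to a top-k membership test against a diagonal ground truth, matching the torch.arange labels used in that code. A minimal sketch, with a random similarity matrix standing in for the real text/image logits:

import torch

def recall_at_k(logits: torch.Tensor, k: int) -> float:
    # logits[i, j]: similarity of query i to candidate j; the correct match sits on the diagonal.
    topk = logits.topk(k, dim=1).indices                 # [N, k] retrieved candidate indices per query
    labels = torch.arange(logits.size(0)).view(-1, 1)    # [N, 1] ground-truth index per query
    hits = (topk == labels).sum(dim=1) > 0               # did the true match land in the top-k?
    return (hits.float().mean() * 100).item()

# toy usage: random similarities instead of the real feature products
sim = torch.randn(100, 100)
for k in (1, 5, 10):
    print(f"R@{k}: {recall_at_k(sim, k):.2f}%")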
- opt = pl_module.opt - model = pl_module.model - epoch = trainer.current_epoch - optimizer = trainer.optimizers[0] - - # if not opt.noamopt and not opt.reduce_on_plateau: - # # Assign the learning rate - # if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: - # frac = ( - # epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every - # decay_factor = opt.learning_rate_decay_rate ** frac - # opt.current_lr = opt.learning_rate * decay_factor - # else: - # opt.current_lr = opt.learning_rate - # utils.set_lr(optimizer, opt.current_lr) # set the decayed rate - # # Assign the scheduled sampling prob - # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: - # frac = ( - # epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every - # opt.ss_prob = min(opt.scheduled_sampling_increase_prob * - # frac, opt.scheduled_sampling_max_prob) - # model.ss_prob = opt.ss_prob - - # # If start self critical training - # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: - # sc_flag = True - # init_scorer(opt.cached_tokens) - # else: - # sc_flag = False - - # # If start structure loss training - # if opt.structure_after != -1 and epoch >= opt.structure_after: - # struc_flag = True - # init_scorer(opt.cached_tokens) - # else: - # struc_flag = False - - # pl_module.struc_flag = struc_flag - # pl_module.sc_flag = sc_flag - - -class ModelCheckpoint(pl.callbacks.ModelCheckpoint): - - def on_keyboard_interrupt(self, trainer, pl_module): - # Save model when keyboard interrupt - filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') - self._save_model(filepath) - -from param import parse_args -# opt = opts.parse_opt() -args = parse_args() -opt = args - -checkpoint_callback = ModelCheckpoint( - filepath=opt.checkpoint_dir + '{epoch:02d}', - # dirpath=opt.checkpoint_path, - save_last=True, - save_top_k=1, - verbose=True, - # monitor='to_monitor', - # monitor='val/to_monitor', - # monitor='val/CIDEr', - monitor='val/loss', - mode='min', - # prefix=opt.id+'_', - prefix=opt.id, - # filename=f'{opt.id}_', -) - -verbose = True -# import torch -# if torch.cuda.current_device() in [0, -1]: -if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - verbose = False - -# if verbose: -# print(opt) -# print(""" -# val_image_use, -# save_checkpoint_very -# save_every_epoch, -# save_history-ckpt will be ignored. 
-# """) - -# Lightning defines batch size as batch size per gpu -assert opt.batch_size % torch.cuda.device_count() == 0 -opt.batch_size = opt.batch_size // torch.cuda.device_count() -opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count() - -# If resume from last checkpoint -# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')): -# resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt') -if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')): - resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt') - if verbose: - print('resume from', resume_from) -else: - resume_from = None - -from pytorch_lightning.loggers import WandbLogger -wandb_logger = WandbLogger( - # project='CLIP-ViL-COCOCaption', - project='CLIP-Finetune-COCO', - name=opt.id, -) - -if verbose: - wandb_logger.experiment.config.update(opt) - from pathlib import Path - import glob - import wandb - # src_dir = Path(__file__).resolve().parent.parent - glob_str = "*.py" - base_path = './' - wandb.save(glob_str=glob_str, base_path=base_path) - - glob_str = "**/*.yaml" - base_path = './' - wandb.save(glob_str=glob_str, base_path=base_path) - - # code = wandb.Artifact('project-source', type='code') - # for path in glob.glob('**/*.py', recursive=True): - # code.add_file(path, name='source/'+path) - # print(path) - # wandb.run.use_artifact(code) - - - - -lit = LitModel(opt) -# warning grad_clip_mode is ignored. -trainer = pl.Trainer( - callbacks=[ - OnEpochStartCallback(), - # pl.callbacks.lr_logger.LearningRateLogger() - pl.callbacks.LearningRateMonitor() - ], - default_root_dir=opt.checkpoint_dir, - resume_from_checkpoint=resume_from, - - distributed_backend='ddp', - gpus=torch.cuda.device_count(), - - # gpus=1, - - check_val_every_n_epoch=1, - # max_epochs=opt.max_epochs, - max_epochs=opt.epochs, - # gradient_clip_val=opt.grad_clip_value, - gradient_clip_val=opt.clip_grad_norm, - - checkpoint_callback=checkpoint_callback, - log_gpu_memory='min_max', - # log_save_interval=opt.losses_log_every, - log_every_n_steps=opt.losses_log_every, - profiler=True, - # profiler='simple', - # row_log_interval=10, # what is it? - flush_logs_every_n_steps=10, - num_sanity_val_steps=0, - # val_check_interval=0.01, - # limit_train_batches=500, - # progress_bar_refresh_rate=0, - # fast_dev_run=True, - precision=opt.precision, - logger=wandb_logger -) - -if os.getenv('EVALUATE', '0') == '1': - trainer.test(lit) -else: - trainer.fit(lit) diff --git a/save/README.md b/save/README.md deleted file mode 100644 index 91547b46ffedc91d209fec4c7ac0b8cfb9e447de..0000000000000000000000000000000000000000 --- a/save/README.md +++ /dev/null @@ -1 +0,0 @@ -Directory for checkpoints \ No newline at end of file diff --git a/scripts/build_bpe_subword_nmt.py b/scripts/build_bpe_subword_nmt.py deleted file mode 100644 index bdf5dfa17f18f06c285edb17500b67301c1143dd..0000000000000000000000000000000000000000 --- a/scripts/build_bpe_subword_nmt.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. 
', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: a json file and an hdf5 file -The hdf5 file contains several fields: -/labels is (M,max_length) uint32 array of encoded labels, zero padded -/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the - first and last indices (in range 1..M) of labels for each image -/label_length stores the length of the sequence for each of the M sequences - -The json file has a dict that contains: -- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed -- an 'images' field that is a list holding auxiliary information for each image, - such as in particular the 'split' it was assigned to. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -import numpy as np -import torch -import torchvision.models as models -import skimage.io -from PIL import Image - -import codecs -import tempfile -from subword_nmt import learn_bpe, apply_bpe - -# python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe - -def build_vocab(imgs, params): - # count up the number of words - captions = [] - for img in imgs: - for sent in img['sentences']: - captions.append(' '.join(sent['tokens'])) - captions='\n'.join(captions) - all_captions = tempfile.NamedTemporaryFile(delete=False) - all_captions.close() - with open(all_captions.name, 'w') as txt_file: - txt_file.write(captions) - - # - codecs_output = tempfile.NamedTemporaryFile(delete=False) - codecs_output.close() - with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output: - learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count']) - - with codecs.open(codecs_output.name, encoding='UTF-8') as codes: - bpe = apply_bpe.BPE(codes) - - tmp = tempfile.NamedTemporaryFile(delete=False) - tmp.close() - - tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') - - for _, img in enumerate(imgs): - img['final_captions'] = [] - for sent in img['sentences']: - txt = ' '.join(sent['tokens']) - txt = bpe.segment(txt).strip() - img['final_captions'].append(txt.split(' ')) - tmpout.write(txt) - tmpout.write('\n') - if _ < 20: - print(txt) - - tmpout.close() - tmpin = codecs.open(tmp.name, encoding='UTF-8') - - vocab = learn_bpe.get_vocabulary(tmpin) - vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True) - - # Always insert UNK - print('inserting the special UNK token') - vocab.append('UNK') - - print('Vocab size:', len(vocab)) - - os.remove(all_captions.name) - with open(codecs_output.name, 'r') as codes: - bpe = codes.read() - os.remove(codecs_output.name) - os.remove(tmp.name) - - return vocab, bpe - -def encode_captions(imgs, params, wtoi): - """ - encode all captions into one large array, which will be 1-indexed. 
- also produces label_start_ix and label_end_ix which store 1-indexed - and inclusive (Lua-style) pointers to the first and last caption for - each image in the dataset. - """ - - max_length = params['max_length'] - N = len(imgs) - M = sum(len(img['final_captions']) for img in imgs) # total number of captions - - label_arrays = [] - label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed - label_end_ix = np.zeros(N, dtype='uint32') - label_length = np.zeros(M, dtype='uint32') - caption_counter = 0 - counter = 1 - for i,img in enumerate(imgs): - n = len(img['final_captions']) - assert n > 0, 'error: some image has no captions' - - Li = np.zeros((n, max_length), dtype='uint32') - for j,s in enumerate(img['final_captions']): - label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence - caption_counter += 1 - for k,w in enumerate(s): - if k < max_length: - Li[j,k] = wtoi[w] - - # note: word indices are 1-indexed, and captions are padded with zeros - label_arrays.append(Li) - label_start_ix[i] = counter - label_end_ix[i] = counter + n - 1 - - counter += n - - L = np.concatenate(label_arrays, axis=0) # put all the labels together - assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' - assert np.all(label_length > 0), 'error: some caption had no words?' - - print('encoded captions to array of size ', L.shape) - return L, label_start_ix, label_end_ix, label_length - -def main(params): - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - - seed(123) # make reproducible - - # create the vocab - vocab, bpe = build_vocab(imgs, params) - itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table - wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table - - # encode captions in large arrays, ready to ship to hdf5 file - L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) - - # create output h5 file - N = len(imgs) - f_lb = h5py.File(params['output_h5']+'_label.h5', "w") - f_lb.create_dataset("labels", dtype='uint32', data=L) - f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) - f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) - f_lb.create_dataset("label_length", dtype='uint32', data=label_length) - f_lb.close() - - # create output json file - out = {} - out['ix_to_word'] = itow # encode the (1-indexed) vocab - out['images'] = [] - out['bpe'] = bpe - for i,img in enumerate(imgs): - - jimg = {} - jimg['split'] = img['split'] - if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need - if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) - - if params['images_root'] != '': - with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: - jimg['width'], jimg['height'] = _img.size - - out['images'].append(jimg) - - json.dump(out, open(params['output_json'], 'w')) - print('wrote ', params['output_json']) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_json', default='data.json', help='output json file') - parser.add_argument('--output_h5', default='data', help='output h5 file') - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - - # options - parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') - parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) - - diff --git a/scripts/clip_prepro_feats.py b/scripts/clip_prepro_feats.py deleted file mode 100644 index b7a45c829fa5c19e36509170135835c6d6bc8d67..0000000000000000000000000000000000000000 --- a/scripts/clip_prepro_feats.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Preprocess a raw json dataset into features files for use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: two folders of features -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - -preprocess = Compose([ - Resize((448, 448), interpolation=Image.BICUBIC), - CenterCrop((448, 448)), - ToTensor() -]) - - -from clip.clip import load -from timm.models.vision_transformer import resize_pos_embed -import timm - -from captioning.utils.resnet_utils import myResnet -import captioning.utils.resnet as resnet - -from tqdm import tqdm - - -def main(params): - if params["model_type"] != 'vit_base_patch32_224_in21k': - model, transform = load(params["model_type"], jit=False) - else: - model = timm.create_model(params["model_type"], pretrained=True) - model = model.cuda() - - if params["model_type"] != 'vit_base_patch32_224_in21k': - save_model_type = params["model_type"].split("-")[0] - mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) - std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) - - if "RN" in params["model_type"]: - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) - pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) - model.visual.attnpool.positional_embedding = pos_embed - - else: - save_model_type = 'vit_base' - mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) - std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) - - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) - pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) - model.pos_embed = pos_embed - - if params["model_type"] == "ViT-B/32": - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) - pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) - model.visual.positional_embedding = pos_embed - imgs = json.load(open(params['input_json'], 'r')) - - imgs = imgs['images'] - - if args.n_jobs > 1: - print('Total imgs:', len(imgs)) - print('Using {} jobs'.format(args.n_jobs)) - print('job id:', args.job_id) - imgs = imgs[args.job_id::args.n_jobs] - - N = len(imgs) - - seed(123) # make reproducible - - dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' - dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' - if not os.path.isdir(dir_fc): - os.mkdir(dir_fc) - if not os.path.isdir(dir_att): - os.mkdir(dir_att) - - for i,img in enumerate(tqdm(imgs)): - # load the image - with torch.no_grad(): - - image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) - image = torch.tensor(np.stack([image])).cuda() - image -= mean - image /= std - if "RN" in params["model_type"]: - tmp_att, tmp_fc = model.encode_image(image) - tmp_att = tmp_att[0].permute(1, 2, 0) - tmp_fc = 
tmp_fc[0] - elif params["model_type"] == 'vit_base_patch32_224_in21k': - x = model(image) - tmp_fc = x[0, 0, :] - tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - else: - x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] - x = model.visual.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - - for layer_idx, layer in enumerate(model.visual.transformer.resblocks): - x = layer(x) - - x = x.permute(1, 0, 2) - tmp_fc = x[0, 0, :] - tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - - np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - - - # if i % 1000 == 0: - # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', dir_fc, dir_att) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') - - # options - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') - parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') - - parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') - parser.add_argument('--job_id', default=0, type=int, help='job id') - parser.add_argument('--batch_size', default=1, type=int, help='batch size') - - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts/clipscore_prepro_feats.py b/scripts/clipscore_prepro_feats.py deleted file mode 100644 index 72df6a02e55c3828dbc043fa272808acf1ee9f7e..0000000000000000000000000000000000000000 --- a/scripts/clipscore_prepro_feats.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Preprocess a raw json dataset into features files for use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: two folders of features -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - -# preprocess = Compose([ -# Resize((448, 448), interpolation=Image.BICUBIC), -# CenterCrop((448, 448)), -# ToTensor() -# ]) - - -# from clip.clip import load -# from timm.models.vision_transformer import resize_pos_embed -# import timm - -# from captioning.utils.resnet_utils import myResnet -# import captioning.utils.resnet as resnet - -from captioning.utils.clipscore import CLIPScore - -from tqdm import tqdm - - - -def main(params): - - clipscore_model = CLIPScore() - clipscore_model.to('cuda') - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - - if args.n_jobs > 1: - print('Total imgs:', len(imgs)) - print('Using {} jobs'.format(args.n_jobs)) - print('job id:', args.job_id) - imgs = imgs[args.job_id::args.n_jobs] - - N = len(imgs) - - seed(123) # make reproducible - - # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' - # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' - - vis_dir_fc = params['output_dir']+'_clipscore_vis' - if not os.path.isdir(vis_dir_fc): - os.mkdir(vis_dir_fc) - - # text_dir_fc = params['output_dir']+'_clipscore_text' - # if not os.path.isdir(text_dir_fc): - # os.mkdir(text_dir_fc) - - # if not os.path.isdir(dir_att): - # os.mkdir(dir_att) - - for i, img in enumerate(tqdm(imgs)): - # load the image - - img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) - img_feat = clipscore_model.image_extract(img_path) - img_feat = img_feat.view(512) - - # for d in img['sentences']: - # text = d['raw'].strip() - # text_feat = clipscore_model.text_extract(text) - - - # with torch.no_grad(): - - # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) - # image = torch.tensor(np.stack([image])).cuda() - # image -= mean - # image /= std - # if "RN" in params["model_type"]: - # tmp_att, tmp_fc = model.encode_image(image) - # tmp_att = tmp_att[0].permute(1, 2, 0) - # tmp_fc = tmp_fc[0] - # elif params["model_type"] == 'vit_base_patch32_224_in21k': - # x = model(image) - # tmp_fc = x[0, 0, :] - # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - # else: - # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] - # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] - # x = model.visual.ln_pre(x) - - # x = x.permute(1, 0, 2) # NLD -> LND - - # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): - # x = layer(x) - - # x = x.permute(1, 0, 2) - # tmp_fc = x[0, 0, :] - # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - - np.save(os.path.join(vis_dir_fc, str(img['cocoid'])), 
img_feat.data.cpu().float().numpy()) - # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - - - # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - - if i % 1000 == 0: - print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', vis_dir_fc) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - # dataset_coco.json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') - - # options - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') - # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') - - parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') - parser.add_argument('--job_id', default=0, type=int, help='job id') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts/copy_model.sh b/scripts/copy_model.sh deleted file mode 100644 index 3e0f8945ffcc1aff3016812a0f5ab91465677514..0000000000000000000000000000000000000000 --- a/scripts/copy_model.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh - -if [ ! -d log_$2 ]; then -cp -r log_$1 log_$2 -cd log_$2 -mv infos_$1-best.pkl infos_$2-best.pkl -mv infos_$1.pkl infos_$2.pkl -cd ../ -fi diff --git a/scripts/dump_to_h5df.py b/scripts/dump_to_h5df.py deleted file mode 100644 index b1d3f2c3dfea2c450ad0d1f4e71c1382b66eb095..0000000000000000000000000000000000000000 --- a/scripts/dump_to_h5df.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import h5py -import os -import numpy as np -import json -from tqdm import tqdm - - -def main(params): - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - N = len(imgs) - - if params['fc_input_dir'] is not None: - print('processing fc') - with h5py.File(params['fc_output']) as file_fc: - for i, img in enumerate(tqdm(imgs)): - npy_fc_path = os.path.join( - params['fc_input_dir'], - str(img['cocoid']) + '.npy') - - d_set_fc = file_fc.create_dataset( - str(img['cocoid']), data=np.load(npy_fc_path)) - file_fc.close() - - if params['att_input_dir'] is not None: - print('processing att') - with h5py.File(params['att_output']) as file_att: - for i, img in enumerate(tqdm(imgs)): - npy_att_path = os.path.join( - params['att_input_dir'], - str(img['cocoid']) + '.npz') - - d_set_att = file_att.create_dataset( - str(img['cocoid']), - data=np.load(npy_att_path)['feat']) - file_att.close() - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--fc_output', default='data', help='output h5 filename for fc') - parser.add_argument('--att_output', default='data', help='output h5 file for att') - parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files') - parser.add_argument('--att_input_dir', default=None, help='input directory for numpy att files') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input 
parameters:') - print(json.dumps(params, indent=2)) - - main(params) \ No newline at end of file diff --git a/scripts/dump_to_lmdb.py b/scripts/dump_to_lmdb.py deleted file mode 100644 index 483dae7d7f2ec513968f12937a82666727ef2700..0000000000000000000000000000000000000000 --- a/scripts/dump_to_lmdb.py +++ /dev/null @@ -1,241 +0,0 @@ -# copy from https://github.com/Lyken17/Efficient-PyTorch/tools - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import os.path as osp -import os, sys -import os.path as osp -from PIL import Image -import six -import string - -from lmdbdict import lmdbdict -from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC -import pickle -import tqdm -import numpy as np -import argparse -import json - -import torch -import torch.utils.data as data -from torch.utils.data import DataLoader - -import csv -csv.field_size_limit(sys.maxsize) -FIELDNAMES = ['image_id', 'status'] - -class FolderLMDB(data.Dataset): - def __init__(self, db_path, fn_list=None): - self.db_path = db_path - self.lmdb = lmdbdict(db_path, unsafe=True) - self.lmdb._key_dumps = DUMPS_FUNC['ascii'] - self.lmdb._value_loads = LOADS_FUNC['identity'] - if fn_list is not None: - self.length = len(fn_list) - self.keys = fn_list - else: - raise Error - - def __getitem__(self, index): - byteflow = self.lmdb[self.keys[index]] - - # load image - imgbuf = byteflow - buf = six.BytesIO() - buf.write(imgbuf) - buf.seek(0) - try: - if args.extension == '.npz': - feat = np.load(buf)['feat'] - else: - feat = np.load(buf) - except Exception as e: - print(self.keys[index], e) - return None - - return feat - - def __len__(self): - return self.length - - def __repr__(self): - return self.__class__.__name__ + ' (' + self.db_path + ')' - - -def make_dataset(dir, extension): - images = [] - dir = os.path.expanduser(dir) - for root, _, fnames in sorted(os.walk(dir)): - for fname in sorted(fnames): - if has_file_allowed_extension(fname, [extension]): - path = os.path.join(root, fname) - images.append(path) - - return images - - -def raw_reader(path): - with open(path, 'rb') as f: - bin_data = f.read() - return bin_data - - -def raw_npz_reader(path): - with open(path, 'rb') as f: - bin_data = f.read() - try: - npz_data = np.load(six.BytesIO(bin_data))['feat'] - except Exception as e: - print(path) - npz_data = None - print(e) - return bin_data, npz_data - - -def raw_npy_reader(path): - with open(path, 'rb') as f: - bin_data = f.read() - try: - npy_data = np.load(six.BytesIO(bin_data)) - except Exception as e: - print(path) - npy_data = None - print(e) - return bin_data, npy_data - - -class Folder(data.Dataset): - - def __init__(self, root, loader, extension, fn_list=None): - super(Folder, self).__init__() - self.root = root - if fn_list: - samples = [os.path.join(root, str(_)+extension) for _ in fn_list] - else: - samples = make_dataset(self.root, extension) - - self.loader = loader - self.extension = extension - self.samples = samples - - def __getitem__(self, index): - """ - Args: - index (int): Index - Returns: - tuple: (sample, target) where target is class_index of the target class. 
- """ - path = self.samples[index] - sample = self.loader(path) - - return (path.split('/')[-1].split('.')[0],) + sample - - def __len__(self): - return len(self.samples) - - -def folder2lmdb(dpath, fn_list, write_frequency=5000): - directory = osp.expanduser(osp.join(dpath)) - print("Loading dataset from %s" % directory) - if args.extension == '.npz': - dataset = Folder(directory, loader=raw_npz_reader, extension='.npz', - fn_list=fn_list) - else: - dataset = Folder(directory, loader=raw_npy_reader, extension='.npy', - fn_list=fn_list) - data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) - - # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1])) - lmdb_path = osp.join("%s.lmdb" % (directory)) - isdir = os.path.isdir(lmdb_path) - - print("Generate LMDB to %s" % lmdb_path) - db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity') - - tsvfile = open(args.output_file, 'a') - writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) - names = [] - all_keys = [] - for idx, data in enumerate(tqdm.tqdm(data_loader)): - # print(type(data), data) - name, byte, npz = data[0] - if npz is not None: - db[name] = byte - all_keys.append(name) - names.append({'image_id': name, 'status': str(npz is not None)}) - if idx % write_frequency == 0: - print("[%d/%d]" % (idx, len(data_loader))) - print('writing') - db.flush() - # write in tsv - for name in names: - writer.writerow(name) - names = [] - tsvfile.flush() - print('writing finished') - # write all keys - # txn.put("keys".encode(), pickle.dumps(all_keys)) - # # finish iterating through dataset - # txn.commit() - for name in names: - writer.writerow(name) - tsvfile.flush() - tsvfile.close() - - print("Flushing database ...") - db.flush() - del db - -def parse_args(): - """ - Parse input arguments - """ - parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') - # parser.add_argument('--json) - parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str) - parser.add_argument('--output_file', default='.dump_cache.tsv', type=str) - parser.add_argument('--folder', default='./data/cocobu_att', type=str) - parser.add_argument('--extension', default='.npz', type=str) - - args = parser.parse_args() - return args - -if __name__ == "__main__": - global args - args = parse_args() - - args.output_file += args.folder.split('/')[-1] - if args.folder.find('/') > 0: - args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file - print(args.output_file) - - img_list = json.load(open(args.input_json, 'r'))['images'] - fn_list = [str(_['cocoid']) for _ in img_list] - found_ids = set() - try: - with open(args.output_file, 'r') as tsvfile: - reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) - for item in reader: - if item['status'] == 'True': - found_ids.add(item['image_id']) - except: - pass - fn_list = [_ for _ in fn_list if _ not in found_ids] - folder2lmdb(args.folder, fn_list) - - # Test existing. 
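The scripts/dump_to_lmdb.py file removed above packs a folder of per-image .npz/.npy feature files into a single LMDB through lmdbdict, logs each written key to a TSV so an interrupted run can resume, and finally re-opens the database to check that every recorded key still loads. As an illustrative aside, the sketch below isolates that write/read round trip, mirroring the lmdbdict calls that appear in the deleted code; the toy_feats.lmdb path, the image ids used as keys, and the toy feature arrays are invented for illustration, and this is a rough sketch rather than a tested drop-in.

    # Sketch only: the path and arrays are made up; the lmdbdict calls mirror
    # the ones in the deleted scripts/dump_to_lmdb.py.
    import io
    import numpy as np
    from lmdbdict import lmdbdict
    from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC

    # Write: store raw .npz bytes keyed by a string image id.
    db = lmdbdict('toy_feats.lmdb', mode='w', key_method='ascii', value_method='identity')
    for image_id in ['391895', '522418']:
        buf = io.BytesIO()
        np.savez_compressed(buf, feat=np.random.rand(14, 14, 8).astype(np.float32))
        db[image_id] = buf.getvalue()   # value stays an opaque byte blob
    db.flush()                          # persist before closing
    del db

    # Read back with the same key/value codecs the FolderLMDB dataset configures.
    db = lmdbdict('toy_feats.lmdb', unsafe=True)
    db._key_dumps = DUMPS_FUNC['ascii']
    db._value_loads = LOADS_FUNC['identity']
    feat = np.load(io.BytesIO(db['391895']))['feat']
    print(feat.shape)                   # (14, 14, 8)

The verification pass at the end of the deleted script, which continues immediately below, performs essentially the read half of this round trip for every image id recorded in the TSV.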
-    found_ids = set()
-    with open(args.output_file, 'r') as tsvfile:
-        reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
-        for item in reader:
-            if item['status'] == 'True':
-                found_ids.add(item['image_id'])
-
-    folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids))
-    data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
-    for data in tqdm.tqdm(data_loader):
-        assert data[0] is not None
\ No newline at end of file
diff --git a/scripts/make_bu_data.py b/scripts/make_bu_data.py
deleted file mode 100644
index 211f3e93dd3df9836e542322b0a19eeb581b2e1a..0000000000000000000000000000000000000000
--- a/scripts/make_bu_data.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import base64
-import numpy as np
-import csv
-import sys
-import zlib
-import time
-import mmap
-import argparse
-
-parser = argparse.ArgumentParser()
-
-# output_dir
-parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory')
-parser.add_argument('--output_dir', default='data/cocobu', help='output feature files')
-
-args = parser.parse_args()
-
-csv.field_size_limit(sys.maxsize)
-
-
-FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
-infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv',
-          'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\
-          'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \
-          'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1']
-
-os.makedirs(args.output_dir+'_att')
-os.makedirs(args.output_dir+'_fc')
-os.makedirs(args.output_dir+'_box')
-
-for infile in infiles:
-    print('Reading ' + infile)
-    with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file:
-        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
-        for item in reader:
-            item['image_id'] = int(item['image_id'])
-            item['num_boxes'] = int(item['num_boxes'])
-            for field in ['boxes', 'features']:
-                item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')),
-                        dtype=np.float32).reshape((item['num_boxes'],-1))
-            np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features'])
-            np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0))
-            np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes'])
-
-
-
-
diff --git a/scripts/prepro_feats.py b/scripts/prepro_feats.py
deleted file mode 100644
index 2c98d880d6b0b76ddb21f1bd516c4ce90515b8f3..0000000000000000000000000000000000000000
--- a/scripts/prepro_feats.py
+++ /dev/null
@@ -1,103 +0,0 @@
-"""
-Preprocess a raw json dataset into features files for use in data_loader.py
-
-Input: json file that has the form
-[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
-example element in this list would look like
-{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: two folders of features -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision import transforms as trn -preprocess = trn.Compose([ - #trn.ToTensor(), - trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) -]) - -from captioning.utils.resnet_utils import myResnet -import captioning.utils.resnet as resnet - - -def main(params): - net = getattr(resnet, params['model'])() - net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) - my_resnet = myResnet(net) - my_resnet.cuda() - my_resnet.eval() - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - N = len(imgs) - - seed(123) # make reproducible - - dir_fc = params['output_dir']+'_fc' - dir_att = params['output_dir']+'_att' - if not os.path.isdir(dir_fc): - os.mkdir(dir_fc) - if not os.path.isdir(dir_att): - os.mkdir(dir_att) - - for i,img in enumerate(imgs): - # load the image - I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) - # handle grayscale input images - if len(I.shape) == 2: - I = I[:,:,np.newaxis] - I = np.concatenate((I,I,I), axis=2) - - I = I.astype('float32')/255.0 - I = torch.from_numpy(I.transpose([2,0,1])).cuda() - I = preprocess(I) - with torch.no_grad(): - tmp_fc, tmp_att = my_resnet(I, params['att_size']) - # write to pkl - # print(dir_fc, str(img['cocoid']), tmp_fc.shape, tmp_att.shape, dir_att) - # exit() - np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - - if i % 1000 == 0: - print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', params['output_dir']) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') - - # options - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') - parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') - parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts/prepro_labels.py b/scripts/prepro_labels.py deleted file mode 100644 index 57fd82fb5144e51fd7dfe3e159080dbf29a63567..0000000000000000000000000000000000000000 --- a/scripts/prepro_labels.py +++ /dev/null @@ -1,206 +0,0 @@ -""" -Preprocess a raw json dataset into hdf5/json files for 
use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: a json file and an hdf5 file -The hdf5 file contains several fields: -/labels is (M,max_length) uint32 array of encoded labels, zero padded -/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the - first and last indices (in range 1..M) of labels for each image -/label_length stores the length of the sequence for each of the M sequences - -The json file has a dict that contains: -- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed -- an 'images' field that is a list holding auxiliary information for each image, - such as in particular the 'split' it was assigned to. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -import numpy as np -import torch -import torchvision.models as models -import skimage.io -from PIL import Image - - -def build_vocab(imgs, params): - count_thr = params['word_count_threshold'] - - # count up the number of words - counts = {} - for img in imgs: - for sent in img['sentences']: - for w in sent['tokens']: - counts[w] = counts.get(w, 0) + 1 - cw = sorted([(count,w) for w,count in counts.items()], reverse=True) - print('top words and their counts:') - print('\n'.join(map(str,cw[:20]))) - - # print some stats - total_words = sum(counts.values()) - print('total words:', total_words) - bad_words = [w for w,n in counts.items() if n <= count_thr] - vocab = [w for w,n in counts.items() if n > count_thr] - bad_count = sum(counts[w] for w in bad_words) - print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) - print('number of words in vocab would be %d' % (len(vocab), )) - print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) - - # lets look at the distribution of lengths as well - sent_lengths = {} - for img in imgs: - for sent in img['sentences']: - txt = sent['tokens'] - nw = len(txt) - sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 - max_len = max(sent_lengths.keys()) - print('max length sentence in raw data: ', max_len) - print('sentence length distribution (count, number of words):') - sum_len = sum(sent_lengths.values()) - for i in range(max_len+1): - print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) - - # lets now produce the final annotations - if bad_count > 0: - # additional special UNK token we will use below to map infrequent words to - print('inserting the special UNK token') - vocab.append('UNK') - - for img in imgs: - img['final_captions'] = [] - for 
sent in img['sentences']: - txt = sent['tokens'] - caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] - img['final_captions'].append(caption) - - return vocab - - -def encode_captions(imgs, params, wtoi): - """ - encode all captions into one large array, which will be 1-indexed. - also produces label_start_ix and label_end_ix which store 1-indexed - and inclusive (Lua-style) pointers to the first and last caption for - each image in the dataset. - """ - - max_length = params['max_length'] - N = len(imgs) - M = sum(len(img['final_captions']) for img in imgs) # total number of captions - - label_arrays = [] - label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed - label_end_ix = np.zeros(N, dtype='uint32') - label_length = np.zeros(M, dtype='uint32') - caption_counter = 0 - counter = 1 - for i,img in enumerate(imgs): - n = len(img['final_captions']) - assert n > 0, 'error: some image has no captions' - - Li = np.zeros((n, max_length), dtype='uint32') - for j,s in enumerate(img['final_captions']): - label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence - caption_counter += 1 - for k,w in enumerate(s): - if k < max_length: - Li[j,k] = wtoi[w] - - # note: word indices are 1-indexed, and captions are padded with zeros - label_arrays.append(Li) - label_start_ix[i] = counter - label_end_ix[i] = counter + n - 1 - - counter += n - - L = np.concatenate(label_arrays, axis=0) # put all the labels together - assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' - assert np.all(label_length > 0), 'error: some caption had no words?' - - print('encoded captions to array of size ', L.shape) - return L, label_start_ix, label_end_ix, label_length - - -def main(params): - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - - seed(123) # make reproducible - - # create the vocab - vocab = build_vocab(imgs, params) - itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table - wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table - - # encode captions in large arrays, ready to ship to hdf5 file - L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) - - # create output h5 file - N = len(imgs) - f_lb = h5py.File(params['output_h5']+'_label.h5', "w") - f_lb.create_dataset("labels", dtype='uint32', data=L) - f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) - f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) - f_lb.create_dataset("label_length", dtype='uint32', data=label_length) - f_lb.close() - - # create output json file - out = {} - out['ix_to_word'] = itow # encode the (1-indexed) vocab - out['images'] = [] - for i,img in enumerate(imgs): - - jimg = {} - jimg['split'] = img['split'] - if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need - if 'cocoid' in img: - jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) - elif 'imgid' in img: - jimg['id'] = img['imgid'] - - if params['images_root'] != '': - with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: - jimg['width'], jimg['height'] = _img.size - - out['images'].append(jimg) - - json.dump(out, open(params['output_json'], 'w')) - print('wrote ', params['output_json']) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_json', default='data.json', help='output json file') - parser.add_argument('--output_h5', default='data', help='output h5 file') - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - - # options - parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') - parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts/prepro_ngrams.py b/scripts/prepro_ngrams.py deleted file mode 100644 index f7cdce47deddaae19b97b24c4191c99fcf8f9cb8..0000000000000000000000000000000000000000 --- a/scripts/prepro_ngrams.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Precompute ngram counts of captions, to accelerate cider computation during training time. -""" - -import os -import json -import argparse -from six.moves import cPickle -import captioning.utils.misc as utils -from collections import defaultdict - -import sys -sys.path.append("cider") -from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer - - -def get_doc_freq(refs, params): - tmp = CiderScorer(df_mode="corpus") - for ref in refs: - tmp.cook_append(None, ref) - tmp.compute_doc_freq() - return tmp.document_frequency, len(tmp.crefs) - - -def build_dict(imgs, wtoi, params): - wtoi[''] = 0 - - count_imgs = 0 - - refs_words = [] - refs_idxs = [] - for img in imgs: - if (params['split'] == img['split']) or \ - (params['split'] == 'train' and img['split'] == 'restval') or \ - (params['split'] == 'all'): - #(params['split'] == 'val' and img['split'] == 'restval') or \ - ref_words = [] - ref_idxs = [] - for sent in img['sentences']: - if hasattr(params, 'bpe'): - sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') - tmp_tokens = sent['tokens'] + [''] - tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] - ref_words.append(' '.join(tmp_tokens)) - ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) - refs_words.append(ref_words) - refs_idxs.append(ref_idxs) - count_imgs += 1 - print('total imgs:', count_imgs) - - ngram_words, count_refs = get_doc_freq(refs_words, params) - ngram_idxs, count_refs = get_doc_freq(refs_idxs, params) - print('count_refs:', count_refs) - return ngram_words, ngram_idxs, count_refs - -def main(params): - - imgs = json.load(open(params['input_json'], 'r')) - dict_json = json.load(open(params['dict_json'], 'r')) - itow = dict_json['ix_to_word'] - wtoi = {w:i for i,w in itow.items()} - - # Load bpe - if 'bpe' in dict_json: - import tempfile - import codecs - codes_f = tempfile.NamedTemporaryFile(delete=False) - codes_f.close() - with open(codes_f.name, 'w') as 
f: - f.write(dict_json['bpe']) - with codecs.open(codes_f.name, encoding='UTF-8') as codes: - bpe = apply_bpe.BPE(codes) - params.bpe = bpe - - imgs = imgs['images'] - - ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) - - utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb')) - utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb')) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') - parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') - parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') - parser.add_argument('--split', default='all', help='test, val, train, all') - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - - main(params) diff --git a/scripts/prepro_reference_json.py b/scripts/prepro_reference_json.py deleted file mode 100644 index 683b12b03e0ef5768af2b11d359dc1f814a1e39b..0000000000000000000000000000000000000000 --- a/scripts/prepro_reference_json.py +++ /dev/null @@ -1,69 +0,0 @@ -# coding: utf-8 -""" -Create a reference json file used for evaluation with `coco-caption` repo. -Used when reference json is not provided, (e.g., flickr30k, or you have your own split of train/val/test) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -import sys -import hashlib -from random import shuffle, seed - - -def main(params): - - imgs = json.load(open(params['input_json'][0], 'r'))['images'] - # tmp = [] - # for k in imgs.keys(): - # for img in imgs[k]: - # img['filename'] = img['image_id'] # k+'/'+img['image_id'] - # img['image_id'] = int( - # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint) - # tmp.append(img) - # imgs = tmp - - # create output json file - out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'} - out.update({'images': [], 'annotations': []}) - - cnt = 0 - empty_cnt = 0 - for i, img in enumerate(imgs): - if img['split'] == 'train': - continue - out['images'].append( - {'id': img.get('cocoid', img['imgid'])}) - for j, s in enumerate(img['sentences']): - if len(s) == 0: - continue - s = ' 
'.join(s['tokens']) - out['annotations'].append( - {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt}) - cnt += 1 - - json.dump(out, open(params['output_json'], 'w')) - print('wrote ', params['output_json']) - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', nargs='+', required=True, - help='input json file to process into hdf5') - parser.add_argument('--output_json', default='data.json', - help='output json file') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent=2)) - main(params) - diff --git a/scripts_FineCapEval/clip_prepro_feats.py b/scripts_FineCapEval/clip_prepro_feats.py deleted file mode 100644 index 3b2986a54766752ac353e09e064a6e8abb43e0d5..0000000000000000000000000000000000000000 --- a/scripts_FineCapEval/clip_prepro_feats.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -Preprocess a raw json dataset into features files for use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: two folders of features -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - -preprocess = Compose([ - Resize((448, 448), interpolation=Image.BICUBIC), - CenterCrop((448, 448)), - ToTensor() -]) - - -from clip.clip import load -from timm.models.vision_transformer import resize_pos_embed -import timm - -from captioning.utils.resnet_utils import myResnet -import captioning.utils.resnet as resnet - -from tqdm import tqdm - - -def main(params): - if params["model_type"] != 'vit_base_patch32_224_in21k': - model, transform = load(params["model_type"], jit=False) - else: - model = timm.create_model(params["model_type"], pretrained=True) - model = model.cuda() - - if params["model_type"] != 'vit_base_patch32_224_in21k': - save_model_type = params["model_type"].split("-")[0] - mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) - std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) - - if "RN" in params["model_type"]: - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) - pos_embed.weight = 
resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) - model.visual.attnpool.positional_embedding = pos_embed - - else: - save_model_type = 'vit_base' - mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) - std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) - - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) - pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) - model.pos_embed = pos_embed - - if params["model_type"] == "ViT-B/32": - num_patches = 196 #600 * 1000 // 32 // 32 - pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) - pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) - model.visual.positional_embedding = pos_embed - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - N = len(imgs) - - seed(123) # make reproducible - - dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' - dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' - if not os.path.isdir(dir_fc): - os.mkdir(dir_fc) - if not os.path.isdir(dir_att): - os.mkdir(dir_att) - - for i, img in enumerate(tqdm(imgs)): - with torch.no_grad(): - - # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) - # img_path = os.path.join(params['images_root'], img['file_name']) - - img_path = os.path.join(params['images_root'], img['file_path']) - - image = preprocess(Image.open( img_path ).convert("RGB")) - image = torch.tensor(np.stack([image])).cuda() - image -= mean - image /= std - if "RN" in params["model_type"]: - tmp_att, tmp_fc = model.encode_image(image) - tmp_att = tmp_att[0].permute(1, 2, 0) - tmp_fc = tmp_fc[0] - elif params["model_type"] == 'vit_base_patch32_224_in21k': - x = model(image) - tmp_fc = x[0, 0, :] - tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - else: - x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] - x = model.visual.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - - for layer_idx, layer in enumerate(model.visual.transformer.resblocks): - x = layer(x) - - x = x.permute(1, 0, 2) - tmp_fc = x[0, 0, :] - tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - - # np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - np.save(os.path.join(dir_fc, str(img['id'])), tmp_fc.data.cpu().float().numpy()) - np.savez_compressed(os.path.join(dir_att, str(img['id'])), feat=tmp_att.data.cpu().float().numpy()) - - - # if i % 1000 == 0: - # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', dir_fc, dir_att) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') - - # options - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be 
prepended to file_path in input json') - parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') - parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts_FineCapEval/clipscore_prepro_feats.py b/scripts_FineCapEval/clipscore_prepro_feats.py deleted file mode 100644 index 5e085078ecd67e4e390bc50b543c14d4934cb260..0000000000000000000000000000000000000000 --- a/scripts_FineCapEval/clipscore_prepro_feats.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -Preprocess a raw json dataset into features files for use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: two folders of features -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -from six.moves import cPickle -import numpy as np -import torch -import torchvision.models as models -import skimage.io - -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from PIL import Image -from torch import nn - -# preprocess = Compose([ -# Resize((448, 448), interpolation=Image.BICUBIC), -# CenterCrop((448, 448)), -# ToTensor() -# ]) - - -# from clip.clip import load -# from timm.models.vision_transformer import resize_pos_embed -# import timm - -# from captioning.utils.resnet_utils import myResnet -# import captioning.utils.resnet as resnet - -from captioning.utils.clipscore import CLIPScore - -from tqdm import tqdm - - -def main(params): - - clipscore_model = CLIPScore() - clipscore_model.to('cuda') - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - N = len(imgs) - - seed(123) # make reproducible - - # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' - # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' - - vis_dir_fc = params['output_dir']+'_clipscore_vis' - if not os.path.isdir(vis_dir_fc): - os.mkdir(vis_dir_fc) - - # text_dir_fc = params['output_dir']+'_clipscore_text' - # if not os.path.isdir(text_dir_fc): - # os.mkdir(text_dir_fc) - - # if not os.path.isdir(dir_att): - # os.mkdir(dir_att) - - for i,img in enumerate(tqdm(imgs)): - # load the image - - # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) - # img_path = os.path.join(params['images_root'], img['file_name']) - img_path = os.path.join(params['images_root'], img['file_path']) - - img_feat = clipscore_model.image_extract(img_path) - 
img_feat = img_feat.view(512) - - # for d in img['sentences']: - # text = d['raw'].strip() - # text_feat = clipscore_model.text_extract(text) - - - # with torch.no_grad(): - - # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) - # image = torch.tensor(np.stack([image])).cuda() - # image -= mean - # image /= std - # if "RN" in params["model_type"]: - # tmp_att, tmp_fc = model.encode_image(image) - # tmp_att = tmp_att[0].permute(1, 2, 0) - # tmp_fc = tmp_fc[0] - # elif params["model_type"] == 'vit_base_patch32_224_in21k': - # x = model(image) - # tmp_fc = x[0, 0, :] - # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - # else: - # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] - # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] - # x = model.visual.ln_pre(x) - - # x = x.permute(1, 0, 2) # NLD -> LND - - # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): - # x = layer(x) - - # x = x.permute(1, 0, 2) - # tmp_fc = x[0, 0, :] - # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) - - np.save(os.path.join(vis_dir_fc, str(img['id'])), img_feat.data.cpu().float().numpy()) - # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) - - - # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) - - # if i % 1000 == 0: - # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) - print('wrote ', vis_dir_fc) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - # dataset_coco.json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_dir', default='data', help='output h5 file') - - # options - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') - # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/scripts_FineCapEval/prepro_labels.py b/scripts_FineCapEval/prepro_labels.py deleted file mode 100644 index 48e7d079808760941a78d87435f8f0e2bbcfb280..0000000000000000000000000000000000000000 --- a/scripts_FineCapEval/prepro_labels.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -Preprocess a raw json dataset into hdf5/json files for use in data_loader.py - -Input: json file that has the form -[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] -example element in this list would look like -{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} - -This script reads this json, does some basic preprocessing on the captions -(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays - -Output: a json file and an hdf5 file -The hdf5 file contains several fields: -/labels is (M,max_length) uint32 array of encoded labels, zero padded -/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the - first and last indices (in range 1..M) of labels for each image -/label_length stores the length of the sequence for each of the M sequences - -The json file has a dict that contains: -- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed -- an 'images' field that is a list holding auxiliary information for each image, - such as in particular the 'split' it was assigned to. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import json -import argparse -from random import shuffle, seed -import string -# non-standard dependencies: -import h5py -import numpy as np -import torch -import torchvision.models as models -import skimage.io -from PIL import Image - - -def build_vocab(imgs, params): - count_thr = params['word_count_threshold'] - - # count up the number of words - counts = {} - for img in imgs: - for sent in img['sentences']: - for w in sent['tokens']: - counts[w] = counts.get(w, 0) + 1 - cw = sorted([(count,w) for w,count in counts.items()], reverse=True) - print('top words and their counts:') - print('\n'.join(map(str,cw[:20]))) - - # print some stats - total_words = sum(counts.values()) - print('total words:', total_words) - bad_words = [w for w,n in counts.items() if n <= count_thr] - vocab = [w for w,n in counts.items() if n > count_thr] - bad_count = sum(counts[w] for w in bad_words) - print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) - print('number of words in vocab would be %d' % (len(vocab), )) - print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) - - # lets look at the distribution of lengths as well - sent_lengths = {} - for img in imgs: - for sent in img['sentences']: - txt = sent['tokens'] - nw = len(txt) - sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 - max_len = max(sent_lengths.keys()) - print('max length sentence in raw data: ', max_len) - print('sentence length distribution (count, number of words):') - sum_len = sum(sent_lengths.values()) - for i in range(max_len+1): - print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) - - # lets now produce the final annotations - if bad_count > 0: - # additional special UNK token we will use below to map infrequent words to - print('inserting the special UNK token') - vocab.append('UNK') - - for img in imgs: - img['final_captions'] = [] - for sent in img['sentences']: - txt = sent['tokens'] - caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] - img['final_captions'].append(caption) - - return vocab - - -def encode_captions(imgs, params, wtoi): - """ - encode all captions into one large array, which will be 1-indexed. - also produces label_start_ix and label_end_ix which store 1-indexed - and inclusive (Lua-style) pointers to the first and last caption for - each image in the dataset. 
- """ - - max_length = params['max_length'] - N = len(imgs) - M = sum(len(img['final_captions']) for img in imgs) # total number of captions - - label_arrays = [] - label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed - label_end_ix = np.zeros(N, dtype='uint32') - label_length = np.zeros(M, dtype='uint32') - caption_counter = 0 - counter = 1 - for i,img in enumerate(imgs): - n = len(img['final_captions']) - assert n > 0, 'error: some image has no captions' - - Li = np.zeros((n, max_length), dtype='uint32') - for j,s in enumerate(img['final_captions']): - label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence - caption_counter += 1 - for k,w in enumerate(s): - if k < max_length: - Li[j,k] = wtoi[w] - - # note: word indices are 1-indexed, and captions are padded with zeros - label_arrays.append(Li) - label_start_ix[i] = counter - label_end_ix[i] = counter + n - 1 - - counter += n - - L = np.concatenate(label_arrays, axis=0) # put all the labels together - assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' - assert np.all(label_length > 0), 'error: some caption had no words?' - - print('encoded captions to array of size ', L.shape) - return L, label_start_ix, label_end_ix, label_length - - -def main(params): - - imgs = json.load(open(params['input_json'], 'r')) - imgs = imgs['images'] - - seed(123) # make reproducible - - # # create the vocab - # vocab = build_vocab(imgs, params) - # itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table - # wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table - - itow = imgs['ix_to_word'] - wtoi = {w:i for i, w in itow.items()} - - # encode captions in large arrays, ready to ship to hdf5 file - L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) - - # create output h5 file - N = len(imgs) - f_lb = h5py.File(params['output_h5']+'_label.h5', "w") - f_lb.create_dataset("labels", dtype='uint32', data=L) - f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) - f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) - f_lb.create_dataset("label_length", dtype='uint32', data=label_length) - f_lb.close() - - # create output json file - out = {} - out['ix_to_word'] = itow # encode the (1-indexed) vocab - out['images'] = [] - for i,img in enumerate(imgs): - - jimg = {} - jimg['split'] = img['split'] - if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need - if 'cocoid' in img: - jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) - elif 'imgid' in img: - jimg['id'] = img['imgid'] - - if params['images_root'] != '': - with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: - jimg['width'], jimg['height'] = _img.size - - out['images'].append(jimg) - - json.dump(out, open(params['output_json'], 'w')) - print('wrote ', params['output_json']) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - - # input json - parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') - parser.add_argument('--output_json', default='data.json', help='output json file') - parser.add_argument('--output_h5', default='data', help='output h5 file') - parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') - - # options - parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') - parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') - - args = parser.parse_args() - params = vars(args) # convert to ordinary dict - print('parsed input parameters:') - print(json.dumps(params, indent = 2)) - main(params) diff --git a/tools/eval.py b/tools/eval.py deleted file mode 100644 index 881580737fa554344b1b66ab79c4f1de114759ca..0000000000000000000000000000000000000000 --- a/tools/eval.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import numpy as np - -import time -import os -from six.moves import cPickle - -import captioning.utils.opts as opts -import captioning.models as models -from captioning.data.dataloader import * -# from captioning.data.dataloaderraw import * -import captioning.utils.eval_utils as eval_utils -import argparse -import captioning.utils.misc as utils -import captioning.modules.losses as losses -import torch - -# Input arguments and options -parser = argparse.ArgumentParser() -# Input paths -parser.add_argument('--model', type=str, default='', - help='path to model to evaluate') -parser.add_argument('--cnn_model', type=str, default='resnet101', - help='resnet101, resnet152') -parser.add_argument('--infos_path', type=str, default='', - help='path to infos to evaluate') -parser.add_argument('--only_lang_eval', type=int, default=0, - help='lang eval on saved results') -parser.add_argument('--force', type=int, default=0, - help='force to evaluate no matter if there are results available') -parser.add_argument('--device', type=str, default='cuda', - help='cpu or cuda') -opts.add_eval_options(parser) -opts.add_diversity_opts(parser) -opt = parser.parse_args() - -# Load infos -with open(opt.infos_path, 'rb') as f: - infos = utils.pickle_load(f) - -# override and collect parameters -replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] -ignore = ['start_from'] - -for k in vars(infos['opt']).keys(): - if k in replace: - setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) - elif k not in ignore: - if not k in vars(opt): - vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model - -vocab = infos['vocab'] # ix -> word mapping - -pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') -result_fn = os.path.join('eval_results/', 
opt.id + '_' + opt.split + '.json') - -if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): - # if results existed, then skip, unless force is on - if not opt.force: - try: - if os.path.isfile(result_fn): - print(result_fn) - json.load(open(result_fn, 'r')) - print('already evaluated') - os._exit(0) - except: - pass - - predictions, n_predictions = torch.load(pred_fn) - lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) - print(lang_stats) - os._exit(0) - -# At this point only_lang_eval if 0 -if not opt.force: - # Check out if - try: - # if no pred exists, then continue - tmp = torch.load(pred_fn) - # if language_eval == 1, and no pred exists, then continue - if opt.language_eval == 1: - json.load(open(result_fn, 'r')) - print('Result is already there') - os._exit(0) - except: - pass - -# Setup the model -opt.vocab = vocab -model = models.setup(opt) -del opt.vocab -model.load_state_dict(torch.load(opt.model, map_location='cpu')) -model.to(opt.device) -model.eval() -crit = losses.LanguageModelCriterion() - -# Create the Data Loader instance -if len(opt.image_folder) == 0: - loader = DataLoader(opt) -else: - loader = DataLoaderRaw({'folder_path': opt.image_folder, - 'coco_json': opt.coco_json, - 'batch_size': opt.batch_size, - 'cnn_model': opt.cnn_model}) -# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json -# So make sure to use the vocab in infos file. -loader.dataset.ix_to_word = infos['vocab'] - - -# Set sample options -opt.dataset = opt.input_json -loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, - vars(opt)) - -print('loss: ', loss) -if lang_stats: - print(lang_stats) - -if opt.dump_json == 1: - # dump the json - json.dump(split_predictions, open('vis/vis.json', 'w')) diff --git a/tools/eval_clip_retrieval.py b/tools/eval_clip_retrieval.py deleted file mode 100644 index 639aace1d4558094d93f4bc3a8643e883b26785b..0000000000000000000000000000000000000000 --- a/tools/eval_clip_retrieval.py +++ /dev/null @@ -1,231 +0,0 @@ - -from PIL import Image -# import requests - -from transformers import CLIPProcessor, CLIPModel - -import torch -from torch.utils.data import DataLoader, Dataset - -from pathlib import Path -from tqdm import tqdm -import json -import argparse -import numpy as np - -class COCODataset(Dataset): - def __init__(self, - coco_root="/nas-ssd/jmincho/datasets/COCO/", - gen_caption_path=None, - is_gt=True): - super().__init__() - - self.coco_root = Path(coco_root) - - self.image_dir = self.coco_root.joinpath('images/val2014') - - if is_gt: - print("Loading karpathy splits") - data_info_path = self.coco_root.joinpath('dataset_coco.json') - with open(data_info_path) as f: - karpathy_data = json.load(f) - - data = [] - for datum in karpathy_data['images']: - # karpathy test split - if datum['split'] == 'test': - img_id = datum['filename'].split('.')[0] - new_datum = { - 'img_id': img_id, - 'captions': [d['raw'].strip() for d in datum['sentences']], - } - data.append(new_datum) - else: - print("Loading generated captions") - gen_caption_path = Path(gen_caption_path) - with open(gen_caption_path) as f: - # karpathy_data = json.load(f) - imgTogen_results = json.load(f)['imgToEval'] - data = [] - for img_id, img_data in imgTogen_results.items(): - new_datum = { - 'img_id': img_id, - 'captions': [img_data['caption']], - } - data.append(new_datum) - - self.data = data - print('# images:', len(self.data)) - - 
self.img_transform = processor.feature_extractor - self.tokenizer = processor.tokenizer - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - datum = self.data[idx] - img_id = datum['img_id'] - if 'COCO' not in img_id: - img_id = f'COCO_val2014_{str(img_id).zfill(12)}' - img_fname = f"{img_id}.jpg" - # COCO_val2014_000000522418.jpg - img_path = self.image_dir.joinpath(img_fname) - img = Image.open(img_path).convert("RGB") - - # take first caption - caption = datum['captions'][0] - - return { - "img": img, - "caption": caption, - } - - def collate_fn(self, datum_list): - B = len(datum_list) - imgs = [datum['img'] for datum in datum_list] - images = self.img_transform(imgs, return_tensors="pt") - - captions = [datum['caption'] for datum in datum_list] - - text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True) - batch = { - 'images': images, - 'captions': text_tokens, - } - return batch - - -def compute_similarity(image_features, text_features, bs = 1000): - # compute similarity - max_pairs = image_features.shape[0] - similarity_scores = torch.zeros(max_pairs, max_pairs) - for v in range(0, max_pairs, bs): - for t in range(0, max_pairs, bs): - # print('Processing Visual '+str(v)+' Text '+str(t), end='\r') - batch_visual_emb = image_features[v:v+bs] - batch_caption_emb = text_features[t:t+bs] - - logits = batch_visual_emb @ batch_caption_emb.t() - similarity_scores[v:v+bs,t:t+bs] = logits - - print('Done similarity') - return similarity_scores - -def compute_retrieval(a2b_sims, return_ranks=True): - """ - Args: - a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T) - with shape (num_datapoints, num_datapoints). - - Returns: - Retrieval metrics for that similarity. - """ - npts = a2b_sims.shape[0] - ranks = np.zeros(npts) - top1 = np.zeros(npts) - # loop source embedding indices - for index in range(npts): - # get order of similarities to target embeddings - inds = np.argsort(a2b_sims[index])[::-1] - # find where the correct embedding is ranked - where = np.where(inds == index) - rank = where[0][0] - ranks[index] = rank - # save the top1 result as well - top1[index] = inds[0] - - # Compute metrics - r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) - r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) - r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) - r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks) - medr = np.floor(np.median(ranks)) + 1 - meanr = ranks.mean() + 1 - - report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50, "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10} - - if return_ranks: - return report_dict, (ranks, top1) - else: - return report_dict - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/") - parser.add_argument('--gt', action='store_true') - parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json") - args = parser.parse_args() - - model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") - - device = "cuda" - model = model.to(device) - model.eval() - print(f"Loaded CLIP at {device}") - - batch_size = 1000 - - dataset = COCODataset( - coco_root="/nas-ssd/jmincho/datasets/COCO/", - gen_caption_path=args.gen_caption_path, - is_gt=args.gt - ) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=dataset.collate_fn, - 
shuffle=False, - num_workers=8) - - # fwd all samples - image_features = [] - text_features = [] - for batch_idx, batch in enumerate(tqdm(data_loader)): - # print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r") - # images, texts = batch - - with torch.no_grad(): - images = batch["images"].to(device) - texts = batch["captions"].to(device) - - vision_outputs = model.vision_model(**batch['images']) - text_outputs = model.text_model(**batch['captions']) - - image_embeds = vision_outputs[1] - image_embeds = model.visual_projection(image_embeds) - - text_embeds = text_outputs[1] - text_embeds = model.text_projection(text_embeds) - - # normalized features - image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) - - text_features.append(text_embeds.detach().cpu()) - image_features.append(image_embeds.detach().cpu()) - - image_features = torch.cat(image_features, 0) - text_features = torch.cat(text_features, 0) - print('Done forward') - - # normalized features - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - - # if not single_caption: - # for cap_idx in range(text_features.shape[1]): - # similarity_scores = compute_similarity(image_features, text_features[:,cap_idx,:]) - # i2t_dict = compute_retrieval(similarity_scores.numpy()) - # t2i_dict = compute_retrieval(similarity_scores.t().numpy()) - # print(cap_idx, 'i2t', i2t_dict) - # print(cap_idx, 't2i', t2i_dict) - # else: - similarity_scores = compute_similarity(image_features, text_features) - i2t_dict = compute_retrieval(similarity_scores.numpy()) - t2i_dict = compute_retrieval(similarity_scores.t().numpy()) - print('i2t', i2t_dict) - print('t2i', t2i_dict) diff --git a/tools/eval_finecapeval.py b/tools/eval_finecapeval.py deleted file mode 100644 index 43916493adfb0736bc97589512f0b23c12154626..0000000000000000000000000000000000000000 --- a/tools/eval_finecapeval.py +++ /dev/null @@ -1,204 +0,0 @@ - -from tqdm import tqdm -from pprint import pprint -import pandas as pd -import argparse -import re -import json -import nltk -from nltk.tokenize import word_tokenize -from nltk.stem.porter import PorterStemmer -p_stemmer = PorterStemmer() - -# nltk.download('punkt') -# nltk.download('wordnet') -# nltk.download('stopwords') - -import language_evaluation -evaluator = language_evaluation.CocoEvaluator() - - -def nltk_process(text): - # Tokenization - nltk_tokenList = word_tokenize(text) - - # Stemming - nltk_stemedList = [] - for word in nltk_tokenList: - nltk_stemedList.append(p_stemmer.stem(word)) - - filtered_sentence = nltk_stemedList - - # Removing Punctuation - - tokens = [re.sub(r'[^a-zA-Z0-9]', '', tok) for tok in filtered_sentence] - - text = " ".join(tokens) - - return text - - -def calculate_finegrained_scores(pred_id2sent, id2caption, use_coco_eval=False): - if use_coco_eval: - n_total = 0 - refs = [] - hyps = [] - for id, gt_captions in id2caption.items(): - pred_sent = pred_id2sent[id] - - refs.append(gt_captions) - hyps.append(pred_sent) - - n_total += 1 - - print('caption') - results = evaluator.run_evaluation(hyps, refs) - pprint(results) - - n_total = 0 - total_score = 0 - for id, gt_phrases in id2background.items(): - pred_sent = pred_id2sent[id] - - score = 0 - n_phrases = len(gt_phrases) - - for gt_phrase in gt_phrases: - word_score = 0 - for gt_word in gt_phrase.split(): - if gt_word in pred_sent: - word_score += 1 - if 
len(gt_phrase.split()) > 0: - score += word_score / len(gt_phrase.split()) - - if n_phrases > 0: - score /= n_phrases - - total_score += score - n_total += 1 - print('background') -# print('# retrieved words:', n_retrieved) - print(f'Acc: {total_score / n_total * 100:.2f}') - - n_total = 0 - total_score = 0 - for id, gt_phrases in id2object.items(): - pred_sent = pred_id2sent[id] - - score = 0 - n_phrases = len(gt_phrases) - - for gt_phrase in gt_phrases: - word_score = 0 - for gt_word in gt_phrase.split(): - if gt_word in pred_sent: - word_score += 1 - if len(gt_phrase.split()) > 0: - score += word_score / len(gt_phrase.split()) - - if n_phrases > 0: - score /= n_phrases - - total_score += score - n_total += 1 - print('object') -# print('# retrieved words:', n_retrieved) - print(f'Acc: {total_score / n_total * 100:.2f}') - - n_total = 0 - total_score = 0 - for id, gt_phrases in id2relation.items(): - pred_sent = pred_id2sent[id] - - score = 0 - n_phrases = len(gt_phrases) - - for gt_phrase in gt_phrases: - word_score = 0 - for gt_word in gt_phrase.split(): - if gt_word in pred_sent: - word_score += 1 - if len(gt_phrase.split()) > 0: - score += word_score / len(gt_phrase.split()) - - if n_phrases > 0: - score /= n_phrases - - total_score += score - n_total += 1 - print('relation') -# print('# retrieved words:', n_retrieved) - print(f'Acc: {total_score / n_total * 100:.2f}') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--finecapeval_path', type=str, default="data/FineCapEval.csv") - parser.add_argument('--generated_id2caption', type=str, default="FineCapEval_results/mle.json") - args = parser.parse_args() - - df = pd.read_csv(args.finecapeval_path) - assert df.shape == (5000, 5) - - generated_id2caption = json.load(open(args.generated_id2caption, 'r')) - - print("Preprocessing GT FineCapEval data...") - id2caption = {} - id2background = {} - id2object = {} - id2relation = {} - - for row in tqdm(df.itertuples(), total=len(df)): - - id = row.image.split('.')[0] - caption = row.caption - background = row.background - object = row.object - relation = row.relation - - if not isinstance(caption, str): - continue - if not isinstance(background, str): - continue - if not isinstance(object, str): - continue - if not isinstance(relation, str): - continue - - if id not in id2caption: - id2caption[id] = [] - id2background[id] = [] - id2object[id] = [] - id2relation[id] = [] - - id2caption[id].append(caption) - - phrases = [] - for phrase in background.lower().split('\;'): - if len(phrase) > 1: - phrase = nltk_process(phrase) - phrases.append(phrase) - id2background[id].extend(phrases) - - phrases = [] - for phrase in object.lower().split('\;'): - if len(phrase) > 1: - phrase = nltk_process(phrase) - phrases.append(phrase) - id2object[id].extend(phrases) - - phrases = [] - for phrase in relation.lower().split('\;'): - if len(phrase) > 1: - phrase = nltk_process(phrase) - phrases.append(phrase) - id2relation[id].extend(phrases) - - print("Calculating scores...") - calculate_finegrained_scores( - generated_id2caption, - id2caption, - use_coco_eval=True) - - - diff --git a/tools/finecapeval_inference.py b/tools/finecapeval_inference.py deleted file mode 100644 index 260b083e00df7c9b2349be23fd2a09591dec3f2b..0000000000000000000000000000000000000000 --- a/tools/finecapeval_inference.py +++ /dev/null @@ -1,186 +0,0 @@ -import sys -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim - -import numpy as np - -import time 
-import os -from collections import defaultdict -import json - -import captioning.utils.opts as opts -import captioning.models as models -from captioning.data.pth_loader import CaptionDataset -import captioning.utils.eval_utils as eval_utils -# import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils -import captioning.utils.misc as utils -from captioning.utils.rewards import init_scorer, get_self_critical_reward -from captioning.modules.loss_wrapper import LossWrapper - -import pytorch_lightning as pl - - -class ModelCheckpoint(pl.callbacks.ModelCheckpoint): - - def on_keyboard_interrupt(self, trainer, pl_module): - # Save model when keyboard interrupt - filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') - self._save_model(filepath) - - -if __name__ == '__main__': - - device = 'cuda' - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--reward', type=str, default='mle') - args = parser.parse_args() - - if args.reward == 'mle': - cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml' - else: - cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml' - - print("Loading cfg from", cfg) - - opt = opts.parse_opt(parse=False, cfg=cfg) - - dataset = CaptionDataset(opt) - - opt.vocab_size = dataset.vocab_size - opt.seq_length = dataset.seq_length - - opt.batch_size = 40 - - opt.vocab = dataset.get_vocab() - - model = models.setup(opt) - del opt.vocab - - ckpt_path = opt.checkpoint_path + '-last.ckpt' - - print("Loading checkpoint from", ckpt_path) - raw_state_dict = torch.load( - ckpt_path, - map_location=device) - - strict = True - - state_dict = raw_state_dict['state_dict'] - - if '_vocab' in state_dict: - model.vocab = utils.deserialize(state_dict['_vocab']) - del state_dict['_vocab'] - elif strict: - raise KeyError - if '_opt' in state_dict: - saved_model_opt = utils.deserialize(state_dict['_opt']) - del state_dict['_opt'] - # Make sure the saved opt is compatible with the curren topt - need_be_same = ["caption_model", - "rnn_type", "rnn_size", "num_layers"] - for checkme in need_be_same: - if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \ - getattr(opt, checkme) in ['updown', 'topdown']: - continue - assert getattr(saved_model_opt, checkme) == getattr( - opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme - elif strict: - raise KeyError - res = model.load_state_dict(state_dict, strict) - print(res) - - opt.use_grammar = False - - lw_model = LossWrapper(model, opt) - - split = 'test' - - print("Building dataloader...") - - test_dataset = torch.utils.data.Subset( - dataset, - dataset.split_ix[split] - ) - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=opt.batch_size, - shuffle=False, - num_workers=4, - drop_last=False, - collate_fn=dataset.collate_func - ) - - eval_kwargs = {'dataset': opt.input_json} - eval_kwargs.update(vars(opt)) - - verbose = eval_kwargs.get('verbose', True) - verbose_beam = eval_kwargs.get('verbose_beam', 0) - verbose_loss = eval_kwargs.get('verbose_loss', 1) - # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) - # lang_eval = eval_kwargs.get('language_eval', 0) - dataset = eval_kwargs.get('dataset', 'coco') - beam_size = eval_kwargs.get('beam_size', 1) - sample_n = eval_kwargs.get('sample_n', 1) - remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) - - crit = lw_model.crit - - model = model.to(device) - - from tqdm import tqdm - - test_id2sent = {} - - model.eval() - - print("running inference...") - - for 
data in tqdm(test_loader): - with torch.no_grad(): - # forward the model to get loss - tmp = [data['fc_feats'], data['att_feats'], - data['labels'], data['masks'], data['att_masks']] - tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp] - - fc_feats, att_feats, labels, masks, att_masks = tmp - - loss = crit(model(fc_feats, att_feats, - labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) - - # forward the model to also get generated samples for each image - # Only leave one feature for each image, in case duplicate sample - tmp_eval_kwargs = eval_kwargs.copy() - tmp_eval_kwargs.update({'sample_n': 1}) - seq, seq_logprobs = model( - fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - seq = seq.data - entropy = - (F.softmax(seq_logprobs, dim=2) * - seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) - perplexity = - \ - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze( - 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) - - # Print beam search - if beam_size > 1 and verbose_beam: - for i in range(fc_feats.shape[0]): - print('\n'.join([utils.decode_sequence(model.vocab, _[ - 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) - print('--' * 10) - sents = utils.decode_sequence(model.vocab, seq) - - for d, sent in zip(data['infos'], sents): - test_id2sent[d['id']] = sent - - res_path = f'FineCapEval_results/clipRN50_{args.reward}.json' - - print("Results save at {}".format(res_path)) - - with open(res_path, 'w') as f: - json.dump(test_id2sent, f) - - diff --git a/tools/train_pl.py b/tools/train_pl.py deleted file mode 100644 index 48ac2d0cf68466bd0e39f9c994056063a0529f27..0000000000000000000000000000000000000000 --- a/tools/train_pl.py +++ /dev/null @@ -1,709 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim - -import numpy as np - -import time -import os -from collections import defaultdict - -import captioning.utils.opts as opts -import captioning.models as models -from captioning.data.pth_loader import CaptionDataset -import captioning.utils.eval_utils as eval_utils -import captioning.utils.misc as utils -from captioning.utils.rewards import init_scorer, get_self_critical_reward -from captioning.modules.loss_wrapper import LossWrapper - -import pytorch_lightning as pl - -import detectron2.utils.comm as d2comm -from detectron2.utils.env import seed_all_rng -seed_all_rng(1234) - - -class LitModel(pl.LightningModule): - def __init__(self, opt): - super().__init__() - self.opt = opt - # Intilaize dataset - self.dataset = CaptionDataset(opt) - opt.vocab_size = self.dataset.vocab_size - opt.seq_length = self.dataset.seq_length - self.batch_size = opt.batch_size - - # Build model - opt.vocab = self.dataset.get_vocab() - model = models.setup(opt) - # print(model) - del opt.vocab - - # wrapper with loss in it. 
- lw_model = LossWrapper(model, opt) - - self.model = model - self.lw_model = lw_model - - self.struc_flag = None - self.sc_flag = None - - # if self.opt.use_clipscore: - # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': - # if CLIP-S+Grammar is used in reward -> Launch another CLIP-S where parameter is unchanged - if getattr(self.opt, 'use_grammar', False): - from captioning.utils.clipscore import CLIPScore - self.val_clipscore_model = CLIPScore( - mode=opt.clipscore_mode, use_grammar=False) - for p in self.val_clipscore_model.parameters(): - p.requires_grad = False - else: - if self.lw_model.clipscore_model is not None: - self.val_clipscore_model = self.lw_model.clipscore_model - else: - from captioning.utils.clipscore import CLIPScore - self.val_clipscore_model = CLIPScore( - mode=opt.clipscore_mode, use_grammar=False) - for p in self.val_clipscore_model.parameters(): - p.requires_grad = False - self.val_clipscore_model.eval() - - # BERTSCORE - from bert_score import BERTScorer - self.bert_scorer = BERTScorer( - lang="en", - # rescale_with_baseline=True, - rescale_with_baseline=False, - device='cpu' - ) - - def forward(self, *args, **kwargs): - """ - I hate this design. Never pretend it as a nn.Module - """ - raise NotImplementedError - - def train_dataloader(self): - train_dataset = torch.utils.data.Subset( - self.dataset, - self.dataset.split_ix['train'] - ) - - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=4, - collate_fn=self.dataset.collate_func - ) - return train_loader - - def val_dataloader(self, split='val'): - val_dataset = torch.utils.data.Subset( - self.dataset, - self.dataset.split_ix[split] - ) - val_loader = torch.utils.data.DataLoader( - val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=4, - drop_last=False, - collate_fn=self.dataset.collate_func - ) - return val_loader - - def test_dataloader(self): - return self.val_dataloader('test') - - def training_step(self, data, batch_idx): - sc_flag, struc_flag = self.sc_flag, self.struc_flag - - tmp = [data['fc_feats'], data['att_feats'], - data['labels'], data['masks'], data['att_masks']] - fc_feats, att_feats, labels, masks, att_masks = tmp - if int(os.getenv('M2_cider', '0')) != 0: - data['gts'] = data['rawgts'] - - if self.opt.use_clipscore: - clip_vis_feats = data['clip_vis_feats'] - model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks, - data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag, - clip_vis_feats=clip_vis_feats) - else: - model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks, - data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) - loss = model_out['loss'] - - data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1] - data_time = torch.tensor(data_time) - - logger_logs = model_out.copy() - # if struc_flag or sc_flag: - # logger_logs['reward'] = model_out['reward'].mean() - # logger_logs['reward_var'] = model_out['reward'].var(1).mean() - if struc_flag or sc_flag: - logger_logs['reward'] = model_out['reward'].mean() - for k in ['CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']: - if k in model_out: - logger_logs[k] = model_out[k] - if struc_flag: - logger_logs['reward_var'] = model_out['reward'].var(1).mean() - - logger_logs['scheduled_sampling_prob'] = torch.tensor( - self.model.ss_prob) - # logger_logs['training_loss'] = loss - logger_logs['loss'] = loss - logger_logs['data_time'] = data_time - - # 
UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 - # Please use self.log(...) inside the lightningModule instead. - - # # log on a step or aggregate epoch metric to the logger and/or progress bar - # # (inside LightningModule) - # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) - # warnings.warn(*args, **kwargs) - # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 - # Please use self.log(...) inside the lightningModule instead. - - # output = { - # 'loss': loss, - # 'log': logger_logs, - # 'progress_bar': {'data_time': data_time} - # } - - for k, v in logger_logs.items(): - if k in ['reward', 'reward_var', 'data_time', 'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']: - self.log('train/'+k, v, prog_bar=True) - else: - self.log('train/'+k, v) - - return loss - - def validation_step(self, data, batch_idx): - model = self.model - crit = self.lw_model.crit - - opt = self.opt - eval_kwargs = {'dataset': opt.input_json} - eval_kwargs.update(vars(opt)) - - # CLIPScore - use_grammar = getattr(self.opt, 'use_grammar', False) - joint_out = getattr(self.opt, 'joint_out', False) - - verbose = eval_kwargs.get('verbose', True) - verbose_beam = eval_kwargs.get('verbose_beam', 0) - verbose_loss = eval_kwargs.get('verbose_loss', 1) - # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) - # lang_eval = eval_kwargs.get('language_eval', 0) - dataset = eval_kwargs.get('dataset', 'coco') - beam_size = eval_kwargs.get('beam_size', 1) - sample_n = eval_kwargs.get('sample_n', 1) - remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) - # Use this nasty way to make other code clean since it's a global configuration - os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) - - predictions = [] - n_predictions = [] - - loss = torch.tensor(0) - if data.get('labels', None) is not None and verbose_loss: - # forward the model to get loss - tmp = [data['fc_feats'], data['att_feats'], - data['labels'], data['masks'], data['att_masks']] - fc_feats, att_feats, labels, masks, att_masks = tmp - - loss = crit(model(fc_feats, att_feats, - labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) - - # forward the model to also get generated samples for each image - # Only leave one feature for each image, in case duplicate sample - tmp_eval_kwargs = eval_kwargs.copy() - tmp_eval_kwargs.update({'sample_n': 1}) - seq, seq_logprobs = model( - fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - seq = seq.data - entropy = - (F.softmax(seq_logprobs, dim=2) * - seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) - perplexity = - \ - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze( - 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) - - # Print beam search - if beam_size > 1 and verbose_beam: - for i in range(fc_feats.shape[0]): - print('\n'.join([utils.decode_sequence(model.vocab, _[ - 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) - print('--' * 10) - sents = utils.decode_sequence(model.vocab, seq) - - # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': - # text_feat = self.lw_model.clipscore_model.text_extract(sents) - text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False) - - text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat) - text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) - - vis_feat = data['clip_vis_feats'] - # if self.opt.clipscore_mode 
== 'clip_s': - # clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s') - - # elif self.opt.clipscore_mode == 'refclip_s': - clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s') - # ref_text = utils.decode_sequence(model.vocab, data['gts']) - - gt_indices = torch.arange(0, len(data['gts'])) - data_gts = [data['gts'][_] for _ in gt_indices.tolist()] - - B = len(data_gts) - - gts = [] - gts_valid_mask = [] - max_n_refs = max([len(_gts) for _gts in data_gts]) - for i in range(len(data_gts)): - _gts = utils.decode_sequence(model.vocab, data_gts[i]) - # pad references - n_ref = len(_gts) - _gts.extend([''] * (max_n_refs - n_ref)) - gts.extend(_gts) - gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref)) - assert len(gts) == B * max_n_refs - assert len(gts_valid_mask) == B * max_n_refs - - ref_text = gts - ref_text_mask = gts_valid_mask - - refclip_s = self.val_clipscore_model( - text_feat=text_cont_feat, img_feat=vis_feat, - ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s') - - # use_grammar = getattr(self.opt, 'use_grammar', False) - # joint_out = getattr(self.opt, 'joint_out', False) - if use_grammar and not joint_out: - with torch.no_grad(): - # grammar_logit = self.val_clipscore_model.grammar_score_head(text_feat.view(-1, 512)) - grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512)) - grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1] - - - # BERTScore - if next(self.bert_scorer._model.parameters()).device != self.device: - self.bert_scorer._model.to(self.device) - self.bert_scorer.device = self.device - - - # [B*K] -> [B, K] - ref_text_per_example = [] - for i in range(B): - ref_text_list_example = [] - for k in range(max_n_refs): - ref = ref_text[i * max_n_refs + k] - if len(ref) > 0: - ref_text_list_example.append(ref) - # assert len(ref_text_list_example) == max_n_refs - ref_text_per_example.append(ref_text_list_example) - assert len(ref_text_per_example) == B - - P, R, F1 = self.bert_scorer.score( - sents, - ref_text_per_example, - ) - bertscore_f1 = F1 - # print('Example 5:') - # for i in range(5): - # print('Generated:', sents[i]) - # print('ref_text:', ref_text_per_example[i]) - # print('BERT-Score:', F1[i].item()) - - - for k, sent in enumerate(sents): - entry = {'image_id': data['infos'][k]['id'], 'caption': sent, - 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()} - if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': - # if self.opt.clipscore_mode == 'clip_s': - # entry['clipscore'] = clipscore[k].item() - # entry['CLIP-S'] = clip_s[k].item() - # elif self.opt.clipscore_mode == 'refclip_s': - entry['CLIP-S'] = clip_s[k].item() - entry['RefCLIP-S'] = refclip_s[k].item() - - if use_grammar and not joint_out: - entry['grammar_prob'] = grammar_prob[k].item() - - # BERT-S - entry['BERT-S'] = bertscore_f1[k].item() - - if eval_kwargs.get('dump_path', 0) == 1: - entry['file_name'] = data['infos'][k]['file_path'] - predictions.append(entry) - if eval_kwargs.get('dump_images', 0) == 1: - # dump the raw image to vis/ folder - cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \ - '" vis/imgs/img' + \ - str(len(predictions)) + '.jpg' # bit gross - print(cmd) - os.system(cmd) - - if verbose: - print('image %s: %s' % - (entry['image_id'], entry['caption'])) - - if sample_n > 1: - eval_utils.eval_split_n(model, n_predictions, [ - fc_feats, att_feats, att_masks, data], eval_kwargs) - 
- output = { - # 'val_loss': loss, - 'loss': loss, - 'predictions': predictions, - 'n_predictions': n_predictions, - } - return output - - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - def validation_epoch_end(self, outputs, split='val'): - outputs = d2comm.gather(outputs) - # master node - if d2comm.is_main_process(): - assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0 - outputs = sum(outputs, []) - - opt = self.opt - # val_loss_mean = sum([_['val_loss'] - # val_loss_mean = sum([_['val_loss'].cpu() - val_loss_mean = sum([_['loss'].cpu() - for _ in outputs]) / len(outputs) - - predictions = sum([_['predictions'] for _ in outputs], []) - if len(outputs[0]['n_predictions']) != 0: - n_predictions = sum([_['n_predictions'] for _ in outputs], []) - else: - n_predictions = [] - - lang_stats = None - if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]: - n_predictions = sorted( - n_predictions, key=lambda x: x['perplexity']) - - if not os.path.isdir('eval_results'): - os.mkdir('eval_results') - torch.save((predictions, n_predictions), os.path.join( - 'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth')) - - if opt.language_eval: - lang_stats = eval_utils.language_eval( - opt.input_json, predictions, n_predictions, vars(opt), split) - - if opt.reduce_on_plateau: - optimizer = self.trainer.optimizers[0] - if 'CIDEr' in lang_stats: - optimizer.scheduler_step(-lang_stats['CIDEr']) - else: - optimizer.scheduler_step(val_loss_mean) - - # out = { - # 'val_loss': val_loss_mean - # } - out = { - 'loss': val_loss_mean - } - out.update(lang_stats) - # out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean - if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': - # if self.opt.clipscore_mode == 'clip_s': - # out['clipscore'] = sum([p['clipscore'] for p in predictions]) / len(predictions) - # print('CLIPScore', out['clipscore']) - # out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions) - # print('CLIP-S', out['CLIP-S']) - # elif self.opt.clipscore_mode == 'refclip_s': - out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions) - print('CLIP-S', out['CLIP-S']) - - out['RefCLIP-S'] = sum([p['RefCLIP-S'] for p in predictions]) / len(predictions) - print('RefCLIP-S', out['RefCLIP-S']) - - if getattr(self.opt, 'use_grammar', False) and not getattr(self.opt, 'joint_out', False): - out['grammar_prob'] = sum([p['grammar_prob'] for p in predictions]) / len(predictions) - print('grammar_prob', out['grammar_prob']) - - out['BERT-S'] = sum([p['BERT-S'] for p in predictions]) / len(predictions) - print('BERT-S', out['BERT-S']) - else: - out = {} - - out = d2comm.all_gather(out)[0] # Only the one from master node - assert len(out) > 0 # make sure the head has index 0 - - # must all be tensors - out = {k: torch.tensor(v) if not torch.is_tensor( - v) else v for k, v in out.items()} - - # return { - # 'progress_bar': {'val_loss': out['val_loss']}, - # 'log': out, - # } - for k, v in out.items(): - # if k in ['loss', 'clipscore', 'RefCLIP-S', 'CIDEr']: - # if split != 'test': - # self.log(f'{split}/{k}', v, prog_bar=True) - # elif k == 'to_monitor': - # if split != 'test': - # self.log(f'{split}/{k}', v) - # else: - self.log(f'{split}/{k}', v) - - def test_epoch_end(self, outputs): - # out = self.validation_epoch_end(outputs, 'test') - # out['progress_bar'] = { - # # 'test_loss': out['progress_bar']['val_loss'] - # 'test_loss': out['progress_bar']['loss'] - # } - # 
out['log']['test_loss'] = out['log']['val_loss'] - # del out['log']['val_loss'] - # del out['log']['to_monitor'] - - # out['log'] = {'test_'+k if 'test' not in k else k:v \ - # for k,v in out['log'].items()} - - # return out - self.validation_epoch_end(outputs, 'test') - - def configure_optimizers(self): - opt = self.opt - model = self.model - - parameters = [p for p in model.parameters() if p.requires_grad] - - if opt.noamopt: - # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer' - optimizer = utils.get_std_opt( - model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) - elif opt.reduce_on_plateau: - # optimizer = utils.build_optimizer(model.parameters(), opt) - optimizer = utils.build_optimizer(parameters, opt) - optimizer = utils.ReduceLROnPlateau(optimizer, - factor=opt.reduce_on_plateau_factor, - patience=opt.reduce_on_plateau_patience) - else: - # optimizer = utils.build_optimizer(model.parameters(), opt) - optimizer = utils.build_optimizer(parameters, opt) - return [optimizer], [] - - def optimizer_step(self, epoch, batch_idx, optimizer, - optimizer_idx, *args, **kwargs): - # warm up lr - opt = self.opt - iteration = self.trainer.global_step - if opt.use_warmup and (iteration < opt.noamopt_warmup): - opt.current_lr = opt.learning_rate * \ - (iteration+1) / opt.noamopt_warmup - utils.set_lr(optimizer, opt.current_lr) - - super().optimizer_step(epoch, batch_idx, optimizer, - optimizer_idx, *args, **kwargs) - - def state_dict(self): - """ - Save the model state dict as well as opt and vocab - """ - state_dict = self.model.state_dict() - device = next(iter(state_dict.values())).device - assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case' - state_dict.update({ - '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device), - '_opt': utils.serialize_to_tensor(self.opt).to(device) - }) - return state_dict - - def load_state_dict(self, state_dict=None, strict=True): - if '_vocab' in state_dict: - self.model.vocab = utils.deserialize(state_dict['_vocab']) - del state_dict['_vocab'] - # elif strict: - # raise KeyError - if '_opt' in state_dict: - saved_model_opt = utils.deserialize(state_dict['_opt']) - del state_dict['_opt'] - opt = self.opt - # Make sure the saved opt is compatible with the curren topt - need_be_same = ["caption_model", - "rnn_type", "rnn_size", "num_layers"] - for checkme in need_be_same: - if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \ - getattr(opt, checkme) in ['updown', 'topdown']: - continue - assert getattr(saved_model_opt, checkme) == getattr( - opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme - # elif strict: - # raise KeyError - self.model.load_state_dict(state_dict, strict) - - -class OnEpochStartCallback(pl.Callback): - - def on_epoch_start(self, trainer, pl_module): - # Update lr/training stage/scheduled sampling prob etc. 
- opt = pl_module.opt - model = pl_module.model - epoch = trainer.current_epoch - optimizer = trainer.optimizers[0] - - if not opt.noamopt and not opt.reduce_on_plateau: - # Assign the learning rate - if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: - frac = ( - epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every - decay_factor = opt.learning_rate_decay_rate ** frac - opt.current_lr = opt.learning_rate * decay_factor - else: - opt.current_lr = opt.learning_rate - utils.set_lr(optimizer, opt.current_lr) # set the decayed rate - # Assign the scheduled sampling prob - if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: - frac = ( - epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every - opt.ss_prob = min(opt.scheduled_sampling_increase_prob * - frac, opt.scheduled_sampling_max_prob) - model.ss_prob = opt.ss_prob - - # If start self critical training - if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: - sc_flag = True - init_scorer(opt.cached_tokens) - else: - sc_flag = False - - # If start structure loss training - if opt.structure_after != -1 and epoch >= opt.structure_after: - struc_flag = True - init_scorer(opt.cached_tokens) - else: - struc_flag = False - - pl_module.struc_flag = struc_flag - pl_module.sc_flag = sc_flag - - -class ModelCheckpoint(pl.callbacks.ModelCheckpoint): - - def on_keyboard_interrupt(self, trainer, pl_module): - # Save model when keyboard interrupt - filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') - self._save_model(filepath) - - -opt = opts.parse_opt() - -checkpoint_callback = ModelCheckpoint( - filepath=opt.checkpoint_path, - # dirpath=opt.checkpoint_path, - save_last=True, - save_top_k=1, - verbose=True, - # monitor='to_monitor', - # monitor='val/to_monitor', - monitor='val/CIDEr', - mode='max', - # prefix=opt.id+'_', - prefix=opt.id, - # filename=f'{opt.id}_', -) - -verbose = True -# import torch -# if torch.cuda.current_device() in [0, -1]: -if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': - verbose = False - -if verbose: - print(opt) - print(""" - val_image_use, - save_checkpoint_very - save_every_epoch, - save_history-ckpt will be ignored. 
- """) - -# Lightning defines batch size as batch size per gpu -assert opt.batch_size % torch.cuda.device_count() == 0 -opt.batch_size = opt.batch_size // torch.cuda.device_count() - -# If resume from last checkpoint -# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')): -# resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt') -if opt.start_from is not None: - resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt') - if os.path.isfile(resume_from): - if verbose: - print('Loading checkpoint from', resume_from) - else: - print("Checkpoint not found:", resume_from) - resume_from = None -else: - resume_from = None - -from pytorch_lightning.loggers import WandbLogger -wandb_logger = WandbLogger( - project='CLIP-ViL-COCOCaption', - name=opt.id, -) - -if verbose: - wandb_logger.experiment.config.update(opt) - from pathlib import Path - import glob - import wandb - # src_dir = Path(__file__).resolve().parent.parent - glob_str = "**/*.py" - base_path = './' - wandb.save(glob_str=glob_str, base_path=base_path) - - # code = wandb.Artifact('project-source', type='code') - # for path in glob.glob('**/*.py', recursive=True): - # code.add_file(path, name='source/'+path) - # print(path) - # wandb.run.use_artifact(code) - - - - -lit = LitModel(opt) -# warning grad_clip_mode is ignored. -trainer = pl.Trainer( - callbacks=[ - OnEpochStartCallback(), - # pl.callbacks.lr_logger.LearningRateLogger() - pl.callbacks.LearningRateMonitor() - ], - default_root_dir=opt.checkpoint_path, - resume_from_checkpoint=resume_from, - distributed_backend='ddp', - check_val_every_n_epoch=1, - max_epochs=opt.max_epochs, - gradient_clip_val=opt.grad_clip_value, - gpus=torch.cuda.device_count(), - checkpoint_callback=checkpoint_callback, - log_gpu_memory='min_max', - # log_save_interval=opt.losses_log_every, - log_every_n_steps=opt.losses_log_every, - profiler=True, - # profiler='simple', - # row_log_interval=10, # what is it? - flush_logs_every_n_steps=10, - num_sanity_val_steps=0, - # val_check_interval=0.01, - # limit_train_batches=500, - # progress_bar_refresh_rate=0, - # fast_dev_run=True, - precision=opt.precision, - logger=wandb_logger -) - -if os.getenv('EVALUATE', '0') == '1': - trainer.test(lit) -else: - trainer.fit(lit)