from __future__ import absolute_import from __future__ import division from __future__ import print_function import json, math import numpy as np import os, sys from six.moves import cPickle from sys import path sys.path.insert(0, os.getcwd()) sys.path.insert(0, 'captioning/') # print('relative captioning is called') import captioning.utils.opts as opts import captioning.models as models from captioning.data.dataloader import * from captioning.data.dataloaderraw import * import argparse import captioning.utils.misc as utils import torch import skimage.io from torch.autograd import Variable from torchvision import transforms as trn preprocess = trn.Compose([ # trn.ToTensor(), trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) from captioning.utils.resnet_utils import myResnet from captioning.utils.resnet_utils import ResNetBatch import captioning.utils.resnet as resnet import wget import tempfile class object: def __init__(self): self.input_fc_dir = '' self.input_json = '' self.batch_size = '' self.id = '' self.sample_max = 1 self.cnn_model = 'resnet101' self.model = '' self.language_eval = 0 self.beam_size = 1 self.temperature = 1.0 return class Captioner(): def __init__(self, is_relative=True, model_path=None, image_feat_params=None, data_type=None, load_resnet=True, diff_feat=None): opt = object() if image_feat_params==None: image_feat_params = {} image_feat_params['model'] = 'resnet101' image_feat_params['model_root'] = '' image_feat_params['att_size'] = 7 # inputs specific to shoe dataset infos_path = os.path.join(model_path, 'infos_best.pkl') model_path = os.path.join(model_path, 'model_best.pth') opt.infos_path = infos_path opt.model_path = model_path opt.beam_size = 1 opt.load_resnet = load_resnet # load pre-trained model, adjusting if URL if opt.infos_path.startswith("http:") or opt.infos_path.startswith("https:"): # create a folder to store the checkpoints for downloading if not os.path.exists('./checkpoints_usersim'): os.mkdir('./checkpoints_usersim') checkpoint_path = os.path.join('./checkpoints_usersim', data_type) if not os.path.exists(checkpoint_path): os.mkdir(checkpoint_path) # set the location for infos infos_loc = os.path.join(checkpoint_path, 'infos_best.pkl') if not os.path.exists(infos_loc): try: wget.download(opt.infos_path, infos_loc) except Exception as err: print(f"[{err}]") else: infos_loc = infos_path if opt.model_path.startswith("http:") or opt.model_path.startswith("https:"): # create a folder to store the checkpoints for downloading if not os.path.exists('./checkpoints_usersim'): os.mkdir('./checkpoints_usersim') checkpoint_path = os.path.join('./checkpoints_usersim', data_type) if not os.path.exists(checkpoint_path): os.mkdir(checkpoint_path) # set the location for models model_loc = os.path.join(checkpoint_path, 'model_best.pth') if not os.path.exists(model_loc): try: wget.download(opt.model_path, model_loc) except Exception as err: print(f"[{err}]") opt.model = model_loc else: opt.model = model_path if os.path.exists(infos_loc): # load existing infos with open(infos_loc, 'rb') as f: infos = cPickle.load(f) self.caption_model = infos["opt"].caption_model # override and collect parameters if len(opt.input_fc_dir) == 0: opt.input_fc_dir = infos['opt'].input_fc_dir opt.input_att_dir = infos['opt'].input_att_dir opt.input_label_h5 = infos['opt'].input_label_h5 if len(opt.input_json) == 0: opt.input_json = infos['opt'].input_json if opt.batch_size == 0: opt.batch_size = infos['opt'].batch_size if len(opt.id) == 0: opt.id = infos['opt'].id ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "model"] for k in vars(infos['opt']).keys(): if k not in ignore: if k in vars(opt): assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent' else: vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model vocab = infos['vocab'] # ix -> word mapping # print('opt:', opt) # Setup the model opt.vocab = vocab model = models.setup(opt) del opt.vocab if torch.cuda.is_available(): model.load_state_dict(torch.load(opt.model)) model.cuda() else: model.load_state_dict(torch.load(opt.model, map_location={'cuda:0': 'cpu'})) model.eval() self.is_relative = is_relative self.model = model self.vocab = vocab self.opt = vars(opt) # Load ResNet for processing images if opt.load_resnet: if image_feat_params['model_root']=='': net = getattr(resnet, image_feat_params['model'])(pretrained=True) else: net = getattr(resnet, image_feat_params['model'])() net.load_state_dict( torch.load(os.path.join(image_feat_params['model_root'], image_feat_params['model'] + '.pth'))) my_resnet = myResnet(net) if torch.cuda.is_available(): my_resnet.cuda() my_resnet.eval() my_resnet_batch = ResNetBatch(net) if torch.cuda.is_available(): my_resnet_batch.cuda() self.my_resnet_batch = my_resnet_batch self.my_resnet = my_resnet self.att_size = image_feat_params['att_size'] # Control the input features of the model if diff_feat == None: if self.caption_model == "show_attend_tell": self.diff_feat = True else: self.diff_feat = False else: self.diff_feat = diff_feat def gen_caption_from_feat(self, feat_target, feat_reference=None): if self.is_relative and feat_reference == None: return None, None if not self.is_relative and not feat_reference == None: return None, None if self.is_relative: if self.diff_feat: fc_feat = torch.cat((feat_target[0], feat_target[0] - feat_reference[0]), dim=-1) att_feat = torch.cat((feat_target[1], feat_target[1] - feat_reference[1]), dim=-1) else: fc_feat = torch.cat((feat_target[0], feat_reference[0]), dim=-1) att_feat = torch.cat((feat_target[1], feat_reference[1]), dim=-1) else: fc_feat = feat_target[0] att_feat = feat_target[1] # Reshape to B x K x C (128,14,14,4096) --> (128,196,4096) att_feat = att_feat.view(att_feat.shape[0], att_feat.shape[1] * att_feat.shape[2], att_feat.shape[-1]) att_masks = np.zeros(att_feat.shape[:2], dtype='float32') for i in range(len(att_feat)): att_masks[i, :att_feat[i].shape[0]] = 1 # set att_masks to None if attention features have same length if att_masks.sum() == att_masks.size: att_masks = None if self.caption_model == 'show_attend_tell': seq, _ = self.model.sample(fc_feat, att_feat, self.opt) else: seq, _ = self.model(fc_feat, att_feat, att_masks=att_masks, opt=self.opt, mode='sample') sents = utils.decode_sequence(self.vocab, seq) return seq, sents def get_vocab_size(self): return len(self.vocab) def get_img_feat(self, img_name): # load the image I = skimage.io.imread(img_name) if len(I.shape) == 2: I = I[:, :, np.newaxis] I = np.concatenate((I, I, I), axis=2) I = I.astype('float32') / 255.0 I = torch.from_numpy(I.transpose([2, 0, 1])) if torch.cuda.is_available(): I = I.cuda() # I = Variable(preprocess(I), volatile=True) with torch.no_grad(): I = preprocess(I) fc, att = self.my_resnet(I, self.att_size) return fc, att def get_img_feat_batch(self, img_names, batchsize=32): if not isinstance(img_names, list): img_names = [img_names] num_images = len(img_names) num_batches = math.ceil(np.float(num_images) / np.float(batchsize)) feature_fc = [] feature_att = [] for id in range(num_batches): startInd = id * batchsize endInd = min((id + 1) * batchsize, num_images) img_names_current_batch = img_names[startInd:endInd] I_current_batch = [] for img_name in img_names_current_batch: I = skimage.io.imread(img_name) if len(I.shape) == 2: I = I[:, :, np.newaxis] I = np.concatenate((I, I, I), axis=2) I = I.astype('float32') / 255.0 I = torch.from_numpy(I.transpose([2, 0, 1])) I_current_batch.append(preprocess(I)) I_current_batch = torch.stack(I_current_batch, dim=0) if torch.cuda.is_available(): I_current_batch = I_current_batch.cuda() # I_current_batch = Variable(I_current_batch, volatile=True) with torch.no_grad(): fc, att = self.my_resnet_batch(I_current_batch, self.att_size) feature_fc.append(fc) feature_att.append(att) feature_fc = torch.cat(feature_fc, dim=0) feature_att = torch.cat(feature_att, dim=0) return feature_fc, feature_att