# mmir_usersim/captioning/captioner.py
# Last commit 5d45228 by yashonwu: "modify captioning"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json, math
import numpy as np
import os, sys
from six.moves import cPickle
from sys import path
sys.path.insert(0, os.getcwd())
sys.path.insert(0, 'captioning/')
# print('relative captioning is called')
import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
from captioning.data.dataloaderraw import *
import argparse
import captioning.utils.misc as utils
import torch
import skimage.io
from torch.autograd import Variable
from torchvision import transforms as trn
# Channel-wise normalization with the standard ImageNet mean/std values.
# There is deliberately no ToTensor step: the image loaders in this module
# build float CHW tensors scaled to [0, 1] themselves before calling this.
preprocess = trn.Compose([
    trn.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
])
from captioning.utils.resnet_utils import myResnet
from captioning.utils.resnet_utils import ResNetBatch
import captioning.utils.resnet as resnet
import wget
import tempfile
class object:
    """Mutable option container handed to the captioning model.

    NOTE: this name shadows the builtin ``object`` inside this module; it is
    kept as-is for backward compatibility with existing callers.

    ``input_fc_dir``, ``input_json``, ``batch_size`` and ``id`` start out
    empty (``''`` / ``0``) so that Captioner fills them in from the options
    saved with the checkpoint.
    """

    def __init__(self):
        self.input_fc_dir = ''
        self.input_json = ''
        # 0 marks the batch size as unset so the checkpoint's value is used.
        # (An empty-string default never matched a `batch_size == 0` check.)
        self.batch_size = 0
        self.id = ''
        self.sample_max = 1
        self.cnn_model = 'resnet101'
        self.model = ''
        self.language_eval = 0
        self.beam_size = 1
        self.temperature = 1.0
class Captioner:
    """Generate (relative) captions for images with a pre-trained captioning model.

    The constructor loads ``infos_best.pkl`` (saved options + vocabulary) and
    ``model_best.pth`` (weights) from ``model_path``, which may be a local
    directory or an HTTP(S) URL prefix. Optionally it also builds a ResNet
    used by :meth:`get_img_feat` / :meth:`get_img_feat_batch` to turn image
    files into features.
    """

    def __init__(self, is_relative=True, model_path=None, image_feat_params=None, data_type=None,
                 load_resnet=True, diff_feat=None):
        """Load the captioning checkpoint and, optionally, the image feature extractor.

        Args:
            is_relative: if true, captions are generated from a (target,
                reference) feature pair; otherwise from target features alone.
            model_path: directory or HTTP(S) URL prefix that contains
                ``infos_best.pkl`` and ``model_best.pth``.
            image_feat_params: dict with keys ``model`` (name of a constructor in
                ``captioning.utils.resnet``), ``model_root`` (directory holding
                ``<model>.pth``; ``''`` means the constructor is called with
                ``pretrained=True``) and ``att_size``. Missing keys default to
                ``'resnet101'``, ``''`` and ``7``.
            data_type: sub-directory of ``./checkpoints_usersim`` where
                downloaded checkpoints are cached; required only for URLs.
            load_resnet: whether to build the ResNet feature extractors.
            diff_feat: if true, relative inputs are ``[target, target - reference]``,
                otherwise ``[target, reference]``. ``None`` selects True only for
                the ``show_attend_tell`` caption model.

        Raises:
            ValueError: ``model_path`` is None, or is a URL while ``data_type`` is None.
            FileNotFoundError: the infos file is missing (or could not be downloaded).
        """
        if model_path is None:
            raise ValueError("model_path must be a directory or URL containing the checkpoints")
        defaults = {'model': 'resnet101', 'model_root': '', 'att_size': 7}
        image_feat_params = {**defaults, **(image_feat_params or {})}

        opt = object()  # the module's option container class
        opt.infos_path = os.path.join(model_path, 'infos_best.pkl')
        opt.model_path = os.path.join(model_path, 'model_best.pth')
        opt.beam_size = 1
        opt.load_resnet = load_resnet

        # Fetch remote checkpoints into the local cache (no-op for local paths).
        infos_loc = self._resolve_checkpoint(opt.infos_path, data_type, 'infos_best.pkl')
        opt.model = self._resolve_checkpoint(opt.model_path, data_type, 'model_best.pth')

        if not os.path.exists(infos_loc):
            raise FileNotFoundError(f"Captioner infos file not found: {infos_loc}")
        # NOTE(security): unpickling can execute arbitrary code, and this file
        # may have been downloaded from a URL -- only load trusted checkpoints.
        with open(infos_loc, 'rb') as f:
            infos = cPickle.load(f)
        saved_opt = infos['opt']
        self.caption_model = saved_opt.caption_model

        # Fill unset data options from the options saved with the checkpoint.
        if not opt.input_fc_dir:
            opt.input_fc_dir = saved_opt.input_fc_dir
            opt.input_att_dir = saved_opt.input_att_dir
            opt.input_label_h5 = saved_opt.input_label_h5
        if not opt.input_json:
            opt.input_json = saved_opt.input_json
        # The option container may mark an unset batch size as '' or 0.
        if not opt.batch_size:
            opt.batch_size = saved_opt.batch_size
        if not opt.id:
            opt.id = saved_opt.id

        # Copy the remaining saved options; any already set here must agree.
        ignore = {"id", "batch_size", "beam_size", "start_from", "language_eval", "model"}
        opt_vars = vars(opt)
        for k, v in vars(saved_opt).items():
            if k in ignore:
                continue
            if k in opt_vars:
                assert opt_vars[k] == v, k + ' option not consistent'
            else:
                opt_vars[k] = v

        vocab = infos['vocab']  # ix -> word mapping
        # opt.vocab is only needed while models.setup builds the model.
        opt.vocab = vocab
        model = models.setup(opt)
        del opt.vocab
        # Load weights onto the CPU first so checkpoints saved on any GPU index
        # can be read, then move the whole model to the GPU if one is present.
        model.load_state_dict(torch.load(opt.model, map_location='cpu'))
        if torch.cuda.is_available():
            model.cuda()
        model.eval()

        self.is_relative = is_relative
        self.model = model
        self.vocab = vocab
        self.opt = vars(opt)

        self.my_resnet = None
        self.my_resnet_batch = None
        if opt.load_resnet:
            self._load_resnet(image_feat_params)

        # Decide which relative features the model consumes.
        if diff_feat is None:
            diff_feat = self.caption_model == "show_attend_tell"
        self.diff_feat = diff_feat

    @staticmethod
    def _resolve_checkpoint(path, data_type, filename):
        """Return a local path for ``path``, downloading it first if it is an HTTP(S) URL.

        Downloads are cached as ``./checkpoints_usersim/<data_type>/<filename>``
        and reused when already present. Download errors are printed rather than
        raised (best effort); callers detect a missing file afterwards.
        """
        if not path.startswith(("http:", "https:")):
            return path
        if data_type is None:
            raise ValueError("data_type is required when loading checkpoints from a URL")
        checkpoint_dir = os.path.join('./checkpoints_usersim', data_type)
        os.makedirs(checkpoint_dir, exist_ok=True)
        local_path = os.path.join(checkpoint_dir, filename)
        if not os.path.exists(local_path):
            try:
                wget.download(path, local_path)
            except Exception as err:
                print(f"[{err}]")
        return local_path

    def _load_resnet(self, image_feat_params):
        """Build the single-image and batched ResNet feature extractors (eval mode)."""
        model_name = image_feat_params['model']
        model_root = image_feat_params['model_root']
        if model_root == '':
            net = getattr(resnet, model_name)(pretrained=True)
        else:
            net = getattr(resnet, model_name)()
            weights_path = os.path.join(model_root, model_name + '.pth')
            net.load_state_dict(torch.load(weights_path, map_location='cpu'))
        my_resnet = myResnet(net)
        my_resnet_batch = ResNetBatch(net)
        if torch.cuda.is_available():
            my_resnet.cuda()
            my_resnet_batch.cuda()
        # Both wrappers are inference-only; put each explicitly in eval mode.
        my_resnet.eval()
        my_resnet_batch.eval()
        self.my_resnet = my_resnet
        self.my_resnet_batch = my_resnet_batch
        self.att_size = image_feat_params['att_size']

    def _require_resnet(self, attr):
        """Return the feature extractor stored in ``attr`` or fail with a clear message."""
        extractor = getattr(self, attr, None)
        if extractor is None:
            raise RuntimeError(
                "ResNet feature extractor not loaded; construct Captioner with load_resnet=True")
        return extractor

    def gen_caption_from_feat(self, feat_target, feat_reference=None):
        """Sample captions from pre-extracted image features.

        Args:
            feat_target: ``(fc, att)`` feature pair for the target image(s),
                indexed as ``[0]`` and ``[1]`` (presumably the tuple returned
                by :meth:`get_img_feat` / :meth:`get_img_feat_batch`).
            feat_reference: matching ``(fc, att)`` pair for the reference
                image(s); required for a relative captioner, forbidden otherwise.

        Returns:
            ``(seq, sents)``: sampled token ids and decoded sentences, or
            ``(None, None)`` when ``feat_reference`` does not match ``is_relative``.
        """
        if self.is_relative and feat_reference is None:
            return None, None
        if not self.is_relative and feat_reference is not None:
            return None, None

        fc_feat, att_feat = feat_target[0], feat_target[1]
        if self.is_relative:
            ref_fc, ref_att = feat_reference[0], feat_reference[1]
            if self.diff_feat:
                ref_fc = fc_feat - ref_fc
                ref_att = att_feat - ref_att
            fc_feat = torch.cat((fc_feat, ref_fc), dim=-1)
            att_feat = torch.cat((att_feat, ref_att), dim=-1)

        # Flatten the spatial grid to B x K x C, e.g. (128,14,14,4096) -> (128,196,4096).
        # flatten (unlike view) accepts non-contiguous tensors, and already-flat
        # (B, K, C) features pass through unchanged.
        att_feat = att_feat.flatten(1, -2)
        # Every image has the same number of attention positions, so no padding
        # mask is needed.
        att_masks = None

        if self.caption_model == 'show_attend_tell':
            seq, _ = self.model.sample(fc_feat, att_feat, self.opt)
        else:
            seq, _ = self.model(fc_feat, att_feat, att_masks=att_masks, opt=self.opt, mode='sample')
        sents = utils.decode_sequence(self.vocab, seq)
        return seq, sents

    def get_vocab_size(self):
        """Return the number of entries in the vocabulary (ix -> word mapping)."""
        return len(self.vocab)

    @staticmethod
    def _load_image(img_name):
        """Read an image file into a float32 (3, H, W) tensor with values in [0, 1].

        Grayscale images are replicated to three channels and the alpha channel
        of RGBA images is dropped, so ``preprocess`` always receives three channels.
        """
        img = skimage.io.imread(img_name)
        if img.ndim == 2:
            img = np.stack((img, img, img), axis=2)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]
        img = img.astype('float32') / 255.0  # assumes 8-bit pixel data
        return torch.from_numpy(img.transpose([2, 0, 1]))

    def get_img_feat(self, img_name):
        """Return ``(fc, att)`` ResNet features for a single image file.

        Raises:
            RuntimeError: the ResNet extractor was not loaded (``load_resnet=False``).
        """
        extractor = self._require_resnet('my_resnet')
        img = self._load_image(img_name)
        if torch.cuda.is_available():
            img = img.cuda()
        with torch.no_grad():
            img = preprocess(img)
            fc, att = extractor(img, self.att_size)
        return fc, att

    def get_img_feat_batch(self, img_names, batchsize=32):
        """Return ``(fc, att)`` ResNet features for several image files.

        Args:
            img_names: a list or tuple of image paths, or a single path.
            batchsize: number of images pushed through the network at once.

        Returns:
            Features of all images concatenated along dim 0, in input order.

        Raises:
            ValueError: no image paths were given or ``batchsize`` < 1.
            RuntimeError: the ResNet extractor was not loaded (``load_resnet=False``).
        """
        if isinstance(img_names, tuple):
            img_names = list(img_names)
        elif not isinstance(img_names, list):
            img_names = [img_names]
        if not img_names:
            raise ValueError("img_names must contain at least one image path")
        if batchsize < 1:
            raise ValueError(f"batchsize must be a positive integer, got {batchsize}")
        extractor = self._require_resnet('my_resnet_batch')

        feature_fc = []
        feature_att = []
        for start in range(0, len(img_names), batchsize):
            batch = torch.stack(
                [preprocess(self._load_image(name)) for name in img_names[start:start + batchsize]],
                dim=0)
            if torch.cuda.is_available():
                batch = batch.cuda()
            with torch.no_grad():
                fc, att = extractor(batch, self.att_size)
            feature_fc.append(fc)
            feature_att.append(att)
        return torch.cat(feature_fc, dim=0), torch.cat(feature_att, dim=0)