# (repo-page scrape residue, kept as a comment so the file stays valid Python)
# pere's picture
# added SentEval
# cd5fcb4
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""
Image Annotation/Search for COCO with Pytorch
"""
from __future__ import absolute_import, division, unicode_literals
import logging
import copy
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
class COCOProjNet(nn.Module):
    """Project image and sentence features into a shared L2-normalized
    embedding space for contrastive image/caption ranking.

    config keys: 'imgdim', 'sentdim', 'projdim'.
    """

    def __init__(self, config):
        super(COCOProjNet, self).__init__()
        self.imgdim = config['imgdim']
        self.sentdim = config['sentdim']
        self.projdim = config['projdim']
        # Single linear layers wrapped in Sequential so the projection
        # can later be deepened without touching callers.
        self.imgproj = nn.Sequential(
            nn.Linear(self.imgdim, self.projdim),
        )
        self.sentproj = nn.Sequential(
            nn.Linear(self.sentdim, self.projdim),
        )

    def _unitnorm(self, x):
        # L2-normalize every row of a 2-D tensor.
        norm = torch.sqrt(torch.pow(x, 2).sum(1, keepdim=True))
        return x / norm.expand_as(x)

    def forward(self, img, sent, imgc, sentc):
        """Return cosine scores for anchors and their contrastive pairs.

        img   : (bsize, imgdim) anchor images
        sent  : (bsize, sentdim) anchor sentences
        imgc  : (bsize, ncontrast, imgdim) contrastive images
        sentc : (bsize, ncontrast, sentdim) contrastive sentences
        Returns four (bsize*ncontrast,) tensors:
        anchor1/anchor2 (anchor img·sent, identical by symmetry),
        img_sentc (anchor img vs contrastive sent),
        sent_imgc (anchor sent vs contrastive img).
        """
        # Tile each anchor across its ncontrast negatives, then flatten
        # everything to 2-D for the linear projections.
        img = img.unsqueeze(1).expand_as(imgc).contiguous().view(-1, self.imgdim)
        sent = sent.unsqueeze(1).expand_as(sentc).contiguous().view(-1, self.sentdim)
        imgc = imgc.view(-1, self.imgdim)
        sentc = sentc.view(-1, self.sentdim)

        imgproj = self._unitnorm(self.imgproj(img))
        imgcproj = self._unitnorm(self.imgproj(imgc))
        sentproj = self._unitnorm(self.sentproj(sent))
        sentcproj = self._unitnorm(self.sentproj(sentc))

        # Dot products of unit vectors == cosine similarity,
        # shape (bsize*ncontrast,)
        anchor1 = (imgproj * sentproj).sum(1)
        anchor2 = (sentproj * imgproj).sum(1)
        img_sentc = (imgproj * sentcproj).sum(1)
        sent_imgc = (sentproj * imgcproj).sum(1)
        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        """Project + L2-normalize sentences -> (bsize, projdim)."""
        return self._unitnorm(self.sentproj(sent))

    def proj_image(self, img):
        """Project + L2-normalize images -> (bsize, projdim)."""
        return self._unitnorm(self.imgproj(img))
class PairwiseRankingLoss(nn.Module):
    """Pairwise hinge ranking loss.

    Pushes anchor image/sentence similarity above the similarity with
    contrastive (negative) pairs by at least `margin`; violations are
    summed over the batch.
    """

    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        """Sum of the two hinge terms (anchor vs contrastive sentences,
        anchor vs contrastive images); all inputs are 1-D score tensors."""
        sent_hinge = (self.margin - anchor1 + img_sentc).clamp(min=0.0)
        img_hinge = (self.margin - anchor2 + sent_imgc).clamp(min=0.0)
        return sent_hinge.sum() + img_hinge.sum()
class ImageSentenceRankingPytorch(object):
    # Image Sentence Ranking on COCO with Pytorch.
    #
    # Trains COCOProjNet with PairwiseRankingLoss and evaluates retrieval
    # in both directions (image->text, text->image) with recall@{1,5,10}
    # and median rank. Requires CUDA (model, loss and eval tensors are
    # moved to GPU unconditionally).
    def __init__(self, train, valid, test, config):
        # train/valid/test: dicts with parallel 'imgfeat' and 'sentfeat'
        # feature lists (sentfeat[i] is a caption of imgfeat[i] —
        # presumably 5 captions per image; see t2i/i2t).
        # config keys used: 'seed', 'projdim', 'margin'.
        # fix seed
        self.seed = config['seed']
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        self.train = train
        self.valid = valid
        self.test = test
        # Feature dimensions are inferred from the first training example.
        self.imgdim = len(train['imgfeat'][0])
        self.sentdim = len(train['sentfeat'][0])
        self.projdim = config['projdim']
        self.margin = config['margin']
        self.batch_size = 128
        self.ncontrast = 30  # negatives sampled per anchor
        self.maxepoch = 20
        self.early_stop = True
        config_model = {'imgdim': self.imgdim,'sentdim': self.sentdim,
                        'projdim': self.projdim}
        self.model = COCOProjNet(config_model).cuda()
        self.loss_fn = PairwiseRankingLoss(margin=self.margin).cuda()
        self.optimizer = optim.Adam(self.model.parameters())

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg,
                     testTxt, testImg):
        # Train features stay on CPU (mini-batches are moved to GPU per
        # step in trainepoch); dev/test go to GPU once, up front.
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).cuda()
        devImg = torch.FloatTensor(devImg).cuda()
        testTxt = torch.FloatTensor(testTxt).cuda()
        testImg = torch.FloatTensor(testImg).cuda()
        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def run(self):
        # Train with early stopping on the dev retrieval score, then
        # evaluate the best model on test. Returns a flat 9-tuple:
        # (bestdevscore, i2t r1/r5/r10/medr, t2i r1/r5/r10/medr),
        # each metric averaged over 5 folds.
        # NOTE(review): assumes dev and test each hold 5 folds of 5000
        # examples (slicing below) — confirm against the data loader.
        self.nepoch = 0
        bestdevscore = -1
        early_stop_count = 0
        stop_train = False

        # Preparing data
        logging.info('prepare data')
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \
            self.prepare_data(self.train['sentfeat'], self.train['imgfeat'],
                              self.valid['sentfeat'], self.valid['imgfeat'],
                              self.test['sentfeat'], self.test['imgfeat'])

        # Training
        # NOTE(review): "<=" runs maxepoch+1 epochs at most (nepoch starts
        # at 0 and is incremented inside trainepoch).
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info('start epoch')
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info('Epoch {0} finished'.format(self.nepoch))

            # Dev metrics, averaged over the 5 folds.
            results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                       'dev': bestdevscore}
            score = 0
            for i in range(5):
                devTxt_i = devTxt[i*5000:(i+1)*5000]
                devImg_i = devImg[i*5000:(i+1)*5000]
                # Compute dev ranks img2txt
                r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg_i,
                                                             devTxt_i)
                results['i2t']['r1'] += r1_i2t / 5
                results['i2t']['r5'] += r5_i2t / 5
                results['i2t']['r10'] += r10_i2t / 5
                results['i2t']['medr'] += medr_i2t / 5
                logging.info("Image to text: {0}, {1}, {2}, {3}"
                             .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
                # Compute dev ranks txt2img
                r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg_i,
                                                             devTxt_i)
                results['t2i']['r1'] += r1_t2i / 5
                results['t2i']['r5'] += r5_t2i / 5
                results['t2i']['r10'] += r10_t2i / 5
                results['t2i']['medr'] += medr_t2i / 5
                logging.info("Text to Image: {0}, {1}, {2}, {3}"
                             .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Model-selection score: sum of the six recalls, fold-averaged.
                score += (r1_i2t + r5_i2t + r10_i2t +
                          r1_t2i + r5_t2i + r10_t2i) / 5

            logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                results['t2i']['r1'], results['t2i']['r5'],
                results['t2i']['r10'], results['t2i']['medr']))
            logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                results['i2t']['r1'], results['i2t']['r5'],
                results['i2t']['r10'], results['i2t']['medr']))

            # early stop on Pearson
            # NOTE(review): despite the comment above, `score` is the summed
            # recall@k — early stopping is on dev retrieval, not Pearson.
            # NOTE(review): early_stop_count is never reset after an
            # improvement, so any 4 total non-improving epochs stop training.
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        # Compute test for the 5 splits
        results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   'dev': bestdevscore}
        for i in range(5):
            testTxt_i = testTxt[i*5000:(i+1)*5000]
            testImg_i = testImg[i*5000:(i+1)*5000]
            # Compute test ranks img2txt
            r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(testImg_i, testTxt_i)
            results['i2t']['r1'] += r1_i2t / 5
            results['i2t']['r5'] += r5_i2t / 5
            results['i2t']['r10'] += r10_i2t / 5
            results['i2t']['medr'] += medr_i2t / 5
            # Compute test ranks txt2img
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(testImg_i, testTxt_i)
            results['t2i']['r1'] += r1_t2i / 5
            results['t2i']['r5'] += r5_t2i / 5
            results['t2i']['r10'] += r10_t2i / 5
            results['t2i']['medr'] += medr_t2i / 5

        return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \
            results['i2t']['r10'], results['i2t']['medr'], \
            results['t2i']['r1'], results['t2i']['r5'], \
            results['t2i']['r10'], results['t2i']['medr']

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        # Run `nepoches` epochs of contrastive training and advance
        # self.nepoch. Every 500 batches, logs dev retrieval metrics.
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = list(np.random.permutation(len(trainTxt)))
            all_costs = []
            # NOTE(review): all_costs is filled but never read/logged.
            for i in range(0, len(trainTxt), self.batch_size):
                # forward
                if i % (self.batch_size*500) == 0 and i > 0:
                    logging.info('samples : {0}'.format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg,
                                                                 devTxt)
                    logging.info("Image to text: {0}, {1}, {2}, {3}".format(
                        r1_i2t, r5_i2t, r10_i2t, medr_i2t))
                    # Compute test ranks txt2img
                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg,
                                                                 devTxt)
                    logging.info("Text to Image: {0}, {1}, {2}, {3}".format(
                        r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                # Anchor mini-batch (moved to GPU here; train data is CPU).
                idx = torch.LongTensor(permutation[i:i + self.batch_size])
                imgbatch = Variable(trainImg.index_select(0, idx)).cuda()
                sentbatch = Variable(trainTxt.index_select(0, idx)).cuda()
                # Negatives: ncontrast indices per anchor, drawn (with
                # replacement — np.random.choice default) from all samples
                # OUTSIDE the current batch.
                idximgc = np.random.choice(permutation[:i] +
                                           permutation[i + self.batch_size:],
                                           self.ncontrast*idx.size(0))
                idxsentc = np.random.choice(permutation[:i] +
                                            permutation[i + self.batch_size:],
                                            self.ncontrast*idx.size(0))
                idximgc = torch.LongTensor(idximgc)
                idxsentc = torch.LongTensor(idxsentc)
                # Get indexes for contrastive images and sentences
                imgcbatch = Variable(trainImg.index_select(0, idximgc)).view(
                    -1, self.ncontrast, self.imgdim).cuda()
                sentcbatch = Variable(trainTxt.index_select(0, idxsentc)).view(
                    -1, self.ncontrast, self.sentdim).cuda()

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch)
                # loss
                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                all_costs.append(loss.data.item())
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def t2i(self, images, captions):
        """
        Text -> image retrieval.
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions
        Assumes 5 consecutive captions per image; image embeddings are
        deduplicated by taking every 5th row. Returns (r1, r5, r10, medr).
        """
        with torch.no_grad():
            # Project images and captions (batched to bound GPU memory).
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            npts = int(img_embed.size(0) / 5)
            # One embedding per unique image (rows 0, 5, 10, ...);
            # requires CUDA (cuda.LongTensor).
            idxs = torch.cuda.LongTensor(range(0, len(img_embed), 5))
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):
                # Get query captions
                queries = sent_embed[5*index: 5*index + 5]
                # Compute scores
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                inds = np.zeros(scores.shape)
                for i in range(len(inds)):
                    # Rank of the ground-truth image for each caption
                    # (descending similarity).
                    inds[i] = np.argsort(scores[i])[::-1]
                    ranks[5 * index + i] = np.where(inds[i] == index)[0][0]

            # Compute metrics
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)

    def i2t(self, images, captions):
        """
        Image -> text retrieval.
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions
        Assumes 5 consecutive captions per image; an image's rank is the
        best rank among its 5 captions. Returns (r1, r5, r10, medr).
        """
        with torch.no_grad():
            # Project images and captions (batched to bound GPU memory).
            img_embed, sent_embed = [], []
            for i in range(0, len(images), self.batch_size):
                img_embed.append(self.model.proj_image(
                    Variable(images[i:i + self.batch_size])))
                sent_embed.append(self.model.proj_sentence(
                    Variable(captions[i:i + self.batch_size])))
            img_embed = torch.cat(img_embed, 0).data
            sent_embed = torch.cat(sent_embed, 0).data

            npts = int(img_embed.size(0) / 5)
            index_list = []
            # NOTE(review): index_list (top-1 caption per image) is
            # collected but never returned or used.
            ranks = np.zeros(npts)
            for index in range(npts):
                # Get query image
                query_img = img_embed[5 * index]
                # Compute scores
                scores = torch.mm(query_img.view(1, -1),
                                  sent_embed.transpose(0, 1)).view(-1)
                scores = scores.cpu().numpy()
                inds = np.argsort(scores)[::-1]
                index_list.append(inds[0])
                # Score: best rank among the image's 5 ground-truth captions.
                rank = 1e20
                for i in range(5*index, 5*index + 5, 1):
                    tmp = np.where(inds == i)[0][0]
                    if tmp < rank:
                        rank = tmp
                ranks[index] = rank

            # Compute metrics
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)