atticus's picture
completed
30a0ec5
raw
history blame
No virus
6.02 kB
"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This script permits one to reproduce the training and experiments of:
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
Finding beans in burgers: Deep semantic-visual embedding with localization.
In Proceedings of CVPR (pp. 3984-3993)
Author: Martin Engilberge
"""
import os
import nltk
import pickle
import torch
from nltk.tokenize import word_tokenize
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import matplotlib.pyplot as plt
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Public attributes, all reset to 0 by ``reset``:
        val: the value passed to the most recent ``update`` call.
        sum: the weighted sum of every value seen since the last reset.
        count: the total weight (sum of ``n``) seen since the last reset.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new measurement.
            n: how many samples ``val`` stands for (default 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against ZeroDivisionError: an update with n=0 on a fresh or
        # just-reset meter leaves count at 0, so keep the reset value for avg.
        self.avg = self.sum / self.count if self.count else 0
class Namespace:
    """Plain attribute holder used to build joint_embedding model arguments by hand."""

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute of the same name.
        attributes = self.__dict__
        for name, value in kwargs.items():
            attributes[name] = value
def _load_dictionary(dir_st):
    """Load the word -> index vocabulary stored in ``<dir_st>/dictionary.txt``.

    The file is read one entry per line. Surrounding whitespace is stripped
    and each word is mapped to its zero-based line number. If a word occurs
    on several lines, the last occurrence wins.

    Args:
        dir_st: directory that contains ``dictionary.txt``.

    Returns:
        dict mapping each word to its line index.

    Raises:
        FileNotFoundError: if ``dictionary.txt`` does not exist in ``dir_st``.
    """
    path_dico = os.path.join(dir_st, 'dictionary.txt')
    if not os.path.exists(path_dico):
        # Fail right away with the offending path. Only printing a message and
        # carrying on would end in the same exception type from open(), but
        # without saying which path was wrong.
        raise FileNotFoundError(
            "Invalid path no dictionary found: {}".format(path_dico))
    with open(path_dico, 'r') as handle:
        # Iterate the handle directly; there is no need to hold all lines.
        return {word.strip(): idx for idx, word in enumerate(handle)}
def preprocess(text):
    """Split ``text`` into sentences and tokenize each sentence into words.

    Sentence boundaries come from the NLTK Punkt English model and words from
    ``word_tokenize``.

    Returns:
        list with one list of tokens per detected sentence, in order.
    """
    splitter = nltk.data.load('tokenizers/punkt/english.pickle')
    return [word_tokenize(sentence) for sentence in splitter.tokenize(text)]
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list.

    Only one level of nesting is removed.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def encode_sentences(sents, embed, dico):
    """Encode every string of ``sents`` as a ``(1, n_tokens, 620)`` float tensor.

    Each string is tokenized with ``preprocess`` and only its first detected
    sentence is encoded. Words missing from ``dico`` use the ``"UNK"`` entry.

    Args:
        sents: iterable of raw sentences (strings).
        embed: embedding table indexed by the integers stored in ``dico``;
            rows are converted with ``torch.from_numpy``.
        dico: mapping from word to row index in ``embed``.

    Returns:
        list of tensors, one per input string, in input order.

    Raises:
        IndexError: if ``preprocess`` finds no sentence in one of the strings.
    """
    # Reuse the single-sentence encoder instead of duplicating its loop, then
    # add the leading singleton dimension callers expect. The deprecated
    # torch.autograd.Variable wrapper is dropped: since PyTorch 0.4 it just
    # returns a plain tensor, so the result type is unchanged.
    return [encode_sentence(sent, embed, dico).unsqueeze(0) for sent in sents]
def encode_sentence(sent, embed, dico, tokenize=True):
    """Encode one sentence as a ``(n_tokens, 620)`` float tensor.

    Args:
        sent: a raw string when ``tokenize`` is true, otherwise an already
            tokenized sequence of words.
        embed: embedding table indexed by the integers stored in ``dico``;
            rows are converted with ``torch.from_numpy``.
        dico: mapping from word to row index in ``embed``.
        tokenize: when true, run ``preprocess`` on ``sent`` and keep only the
            first detected sentence.

    Returns:
        tensor whose i-th row is the embedding of the i-th token. Tokens
        missing from ``dico`` use the ``"UNK"`` entry.
    """
    tokens = preprocess(sent)[0] if tokenize else sent
    encoded = torch.FloatTensor(len(tokens), 620)
    for row, word in enumerate(tokens):
        try:
            vector = embed[dico[word]]
        except KeyError:
            # Out-of-vocabulary word: fall back to the unknown-word embedding.
            vector = embed[dico["UNK"]]
        encoded[row, :620] = torch.from_numpy(vector)
    return encoded
def save_checkpoint(state, is_best, model_name, epoch):
    """Write ``state`` to ``./weights/best_<model_name>.pth.tar`` when it is the best so far.

    Args:
        state: object handed to ``torch.save``.
        is_best: nothing is written unless this is true.
        model_name: used to build the checkpoint file name.
        epoch: accepted for interface compatibility; not used.
    """
    if not is_best:
        return
    # torch.save does not create missing parent directories, so make sure the
    # target folder exists instead of failing on a fresh checkout.
    os.makedirs('./weights', exist_ok=True)
    torch.save(state, './weights/best_' + model_name + ".pth.tar")
def log_epoch(logger, epoch, train_loss, val_loss, lr, batch_train, batch_val, data_train, data_val, recall):
    """Send the scalars of one epoch to ``logger.add_scalar``, using ``epoch`` as the step.

    Args:
        logger: object exposing ``add_scalar(tag, value, step)``.
        epoch: step under which every scalar is recorded.
        train_loss, val_loss: losses; their ratio ``val_loss / train_loss`` is
            logged as 'Learning/Overfitting'.
        lr: learning rate.
        batch_train, batch_val: batch processing times.
        data_train, data_val: data loading times.
        recall: indexable where ``recall[0]`` and ``recall[1]`` hold the
            R@1 / R@5 / R@10 values logged under CapRet and ImgRet, and
            ``recall[2]`` / ``recall[3]`` the matching MedR values.
    """
    def _scalars():
        # A generator keeps evaluation lazy: each value is computed right
        # before it is logged, in the order listed here.
        yield 'Loss/Train', train_loss
        yield 'Loss/Val', val_loss
        yield 'Learning/Rate', lr
        yield 'Learning/Overfitting', val_loss / train_loss
        yield 'Time/Train/Batch Processing', batch_train
        yield 'Time/Val/Batch Processing', batch_val
        yield 'Time/Train/Data loading', data_train
        yield 'Time/Val/Data loading', data_val
        yield 'Recall/Val/CapRet/R@1', recall[0][0]
        yield 'Recall/Val/CapRet/R@5', recall[0][1]
        yield 'Recall/Val/CapRet/R@10', recall[0][2]
        yield 'Recall/Val/CapRet/MedR', recall[2]
        yield 'Recall/Val/ImgRet/R@1', recall[1][0]
        yield 'Recall/Val/ImgRet/R@5', recall[1][1]
        yield 'Recall/Val/ImgRet/R@10', recall[1][2]
        yield 'Recall/Val/ImgRet/MedR', recall[3]

    for tag, value in _scalars():
        logger.add_scalar(tag, value, epoch)
def collate_fn_padded(data):
    """Collate ``(image, caption)`` samples into one batch.

    Returns:
        tuple ``(images, targets, lengths)``: the images stacked along a new
        first dimension, the captions padded to a common length with
        ``pad_sequence(batch_first=True)``, and the original caption lengths.
    """
    image_items, caption_items = zip(*data)
    stacked = torch.stack(image_items, 0)
    cap_lengths = list(map(len, caption_items))
    padded = pad_sequence(caption_items, batch_first=True)
    return stacked, padded, cap_lengths
def collate_fn_cap_padded(data):
    """Collate caption-only samples into a padded batch.

    Returns:
        tuple ``(targets, lengths)``: the captions padded to a common length
        with ``pad_sequence(batch_first=True)`` and their original lengths.
    """
    cap_lengths = list(map(len, data))
    padded = pad_sequence(data, batch_first=True)
    return padded, cap_lengths
def collate_fn_semseg(data):
    """Collate ``(image, size, target)`` samples.

    Only the images are stacked into a tensor. Sizes and targets are returned
    as tuples in sample order, without padding or stacking.
    """
    image_items, sizes, labels = zip(*data)
    return torch.stack(image_items, 0), sizes, labels
def collate_fn_img_padded(data):
    """Stack a sequence of image tensors along a new first dimension."""
    return torch.stack(data, 0)
def load_obj(path):
    """Unpickle and return the object stored at ``path + '.pkl'``.

    ``path`` is given without the ``.pkl`` extension.

    NOTE: ``pickle.load`` can execute arbitrary code, so only use this on
    files from a trusted source.
    """
    pickle_path = os.path.normpath(path + '.pkl')
    with open(pickle_path, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded
def save_obj(obj, path):
    """Pickle ``obj`` to ``path + '.pkl'`` using the highest pickle protocol.

    ``path`` is given without the ``.pkl`` extension.
    """
    pickle_path = os.path.normpath(path + '.pkl')
    with open(pickle_path, 'wb') as stream:
        pickle.dump(obj, stream, pickle.HIGHEST_PROTOCOL)
def show_imgs(imgs_path):
# Open the image files listed in `imgs_path` with PIL and draw them with
# matplotlib in a figure named "Image", titled 'image_<index>'.
# NOTE(review): indentation was lost in this copy of the file, so which of the
# statements after the `for` line belong to the loop body cannot be determined
# from here. Confirm against the upstream source before re-indenting.
# NOTE(review): the handle returned by Image.open is never explicitly closed.
plt.ion()
for i, img_path in enumerate(imgs_path):
img = Image.open(img_path)
plt.figure("Image") # name of the figure window
plt.imshow(img)
plt.axis('on') # pass 'off' instead to hide the axes
plt.title('image_{}'.format(i)) # figure title
plt.ioff()
plt.show()
plt.close()