Spaces:
Build error
Build error
""" | |
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** | |
Copyright (c) 2018 [Thomson Licensing] | |
All Rights Reserved | |
This program contains proprietary information which is a trade secret/business \ | |
secret of [Thomson Licensing] and is protected, even if unpublished, under \ | |
applicable Copyright laws (including French droit d'auteur) and/or may be \ | |
subject to one or more patent(s). | |
Recipient is to retain this program in confidence and is not permitted to use \ | |
or make copies thereof other than as permitted in a written agreement with \ | |
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ | |
by [Thomson Licensing] under express agreement. | |
Thomson Licensing is a company of the group TECHNICOLOR | |
******************************************************************************* | |
This script permits one to reproduce the training and experiments of: | |
Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April). | |
Finding beans in burgers: Deep semantic-visual embedding with localization. | |
In Proceedings of CVPR (pp. 3984-3993) | |
Author: Martin Engilberge | |
""" | |
import os | |
import nltk | |
import pickle | |
import torch | |
from nltk.tokenize import word_tokenize | |
from torch.autograd import Variable | |
from torch.nn.utils.rnn import pad_sequence | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val:   the value passed to the most recent ``update`` call.
        sum:   the accumulated ``val * n`` over all ``update`` calls.
        count: the accumulated ``n`` over all ``update`` calls.
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: the new measurement.
            n: weight of the measurement (number of samples it stands for).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: an update with n=0 before any counted sample
        # would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
class Namespace:
    """Plain attribute container used to build a joint_embedding model by hand.

    Every keyword argument given to the constructor becomes an instance
    attribute of the same name.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
def _load_dictionary(dir_st):
    """Load the word-to-index mapping stored in ``dictionary.txt``.

    Args:
        dir_st: directory that contains the ``dictionary.txt`` file, which is
            read one entry per line.

    Returns:
        dict mapping each line (stripped of surrounding whitespace) to its
        zero-based line index. If a word appears on several lines the last
        index wins.

    Raises:
        FileNotFoundError: if ``dictionary.txt`` does not exist in ``dir_st``.
    """
    path_dico = os.path.join(dir_st, 'dictionary.txt')
    if not os.path.exists(path_dico):
        # Fail here with the real cause instead of printing a message and
        # then crashing inside open() with a less specific error.
        raise FileNotFoundError(
            "Invalid path no dictionary found: " + path_dico)

    with open(path_dico, 'r') as handle:
        return {word.strip(): idx for idx, word in enumerate(handle)}
def preprocess(text):
    """Split ``text`` into sentences and tokenize each sentence into words.

    The sentence splitter is loaded from NLTK's
    ``tokenizers/punkt/english.pickle`` resource on every call.

    Args:
        text: the raw string to process.

    Returns:
        A list with one entry per detected sentence; each entry is the list
        of tokens produced by ``word_tokenize`` for that sentence.
    """
    splitter = nltk.data.load('tokenizers/punkt/english.pickle')
    return [word_tokenize(sentence) for sentence in splitter.tokenize(text)]
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one list.

    Only one level of nesting is removed.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def encode_sentences(sents, embed, dico):
    """Turn each sentence into a tensor of word embeddings.

    Only the first sentence returned by ``preprocess`` is used for each
    input string. Words missing from ``dico`` use the ``"UNK"`` entry.

    Args:
        sents: iterable of raw strings.
        embed: embedding table indexed by word id; each row is handed to
            ``torch.from_numpy`` and written into a 620-wide slot.
        dico: mapping from word to row index in ``embed``; must contain
            ``"UNK"`` whenever an unknown word is encountered.

    Returns:
        A list with one ``Variable`` per input, of shape
        ``(1, number_of_tokens, 620)``.
    """
    encoded = []
    for raw in sents:
        tokens = preprocess(raw)[0]
        holder = Variable(torch.FloatTensor(1, len(tokens), 620))
        for pos, word in enumerate(tokens):
            try:
                vector = embed[dico[word]]
            except KeyError:
                # Unknown word: fall back to the UNK embedding.
                vector = embed[dico["UNK"]]
            holder.data[0, pos] = torch.from_numpy(vector)
        encoded.append(holder)
    return encoded
def encode_sentence(sent, embed, dico, tokenize=True):
    """Encode one sentence as a ``(number_of_tokens, 620)`` float tensor.

    Args:
        sent: a raw string when ``tokenize`` is true, otherwise an already
            tokenized sequence of words.
        embed: embedding table indexed by word id; each row is handed to
            ``torch.from_numpy``.
        dico: mapping from word to row index in ``embed``; must contain
            ``"UNK"`` whenever an unknown word is encountered.
        tokenize: when true, ``sent`` goes through ``preprocess`` and only
            its first sentence is kept.

    Returns:
        A ``torch.FloatTensor`` whose i-th row is the embedding of the i-th
        token (or of ``"UNK"`` if the token is not in ``dico``).
    """
    tokens = preprocess(sent)[0] if tokenize else sent

    encoded = torch.FloatTensor(len(tokens), 620)
    for row, word in enumerate(tokens):
        try:
            vector = embed[dico[word]]
        except KeyError:
            # Unknown word: fall back to the UNK embedding.
            vector = embed[dico["UNK"]]
        encoded[row, :620] = torch.from_numpy(vector)
    return encoded
def save_checkpoint(state, is_best, model_name, epoch):
    """Persist ``state`` as the best checkpoint of ``model_name``.

    Nothing is written unless ``is_best`` is true. The checkpoint goes to
    ``./weights/best_<model_name>.pth.tar`` relative to the current working
    directory, overwriting any previous file of that name.

    Args:
        state: object handed to ``torch.save``.
        is_best: whether this state should be saved.
        model_name: name used to build the checkpoint file name.
        epoch: accepted for interface compatibility; not used here.
    """
    if not is_best:
        return

    # torch.save does not create missing parent directories, so make sure
    # the target folder exists before writing.
    os.makedirs('./weights', exist_ok=True)
    torch.save(state, './weights/best_' + model_name + ".pth.tar")
def log_epoch(logger, epoch, train_loss, val_loss, lr, batch_train, batch_val, data_train, data_val, recall):
    """Write the per-epoch scalars to ``logger`` via ``add_scalar``.

    Args:
        logger: object exposing ``add_scalar(tag, value, step)``.
        epoch: step index passed with every scalar.
        train_loss, val_loss: losses; their ratio ``val_loss / train_loss``
            is logged as ``Learning/Overfitting``.
        lr: learning rate.
        batch_train, batch_val: batch processing times.
        data_train, data_val: data loading times.
        recall: indexable where ``recall[0]`` holds the caption-retrieval
            R@1/R@5/R@10, ``recall[1]`` the image-retrieval R@1/R@5/R@10,
            ``recall[2]`` and ``recall[3]`` the respective median ranks.
    """
    def _scalars():
        # A generator keeps evaluation lazy: each value is computed right
        # before it is logged, in the order listed.
        yield 'Loss/Train', train_loss
        yield 'Loss/Val', val_loss
        yield 'Learning/Rate', lr
        yield 'Learning/Overfitting', val_loss / train_loss
        yield 'Time/Train/Batch Processing', batch_train
        yield 'Time/Val/Batch Processing', batch_val
        yield 'Time/Train/Data loading', data_train
        yield 'Time/Val/Data loading', data_val
        yield 'Recall/Val/CapRet/R@1', recall[0][0]
        yield 'Recall/Val/CapRet/R@5', recall[0][1]
        yield 'Recall/Val/CapRet/R@10', recall[0][2]
        yield 'Recall/Val/CapRet/MedR', recall[2]
        yield 'Recall/Val/ImgRet/R@1', recall[1][0]
        yield 'Recall/Val/ImgRet/R@5', recall[1][1]
        yield 'Recall/Val/ImgRet/R@10', recall[1][2]
        yield 'Recall/Val/ImgRet/MedR', recall[3]

    for tag, value in _scalars():
        logger.add_scalar(tag, value, epoch)
def collate_fn_padded(data):
    """Collate ``(image, caption)`` pairs into a padded batch.

    Args:
        data: sequence of ``(image, caption)`` pairs of tensors.

    Returns:
        Tuple ``(images, targets, lengths)``: the images stacked along a new
        first dimension, the captions padded to a common length with the
        batch dimension first, and the list of original caption lengths.
    """
    batch_images, batch_captions = zip(*data)
    stacked = torch.stack(batch_images, dim=0)
    sizes = list(map(len, batch_captions))
    padded = pad_sequence(batch_captions, batch_first=True)
    return stacked, padded, sizes
def collate_fn_cap_padded(data):
    """Collate a batch made only of caption tensors.

    Args:
        data: sequence of caption tensors.

    Returns:
        Tuple ``(targets, lengths)``: the captions padded to a common length
        with the batch dimension first, and the list of original lengths.
    """
    sizes = list(map(len, data))
    padded = pad_sequence(data, batch_first=True)
    return padded, sizes
def collate_fn_semseg(data):
    """Collate ``(image, size, target)`` triples.

    Only the images are stacked into one tensor; sizes and targets are
    returned as tuples, one entry per sample, without modification.
    """
    batch_images, batch_sizes, batch_targets = zip(*data)
    return torch.stack(batch_images, dim=0), batch_sizes, batch_targets
def collate_fn_img_padded(data):
    """Stack a batch of image tensors along a new first dimension."""
    return torch.stack(data, dim=0)
def load_obj(path):
    """Unpickle and return the object stored at ``path`` + ``'.pkl'``.

    Args:
        path: file path without the ``.pkl`` extension.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    pkl_path = os.path.normpath(path + '.pkl')
    with open(pkl_path, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded
def save_obj(obj, path):
    """Pickle ``obj`` to ``path`` + ``'.pkl'`` with the highest protocol.

    Args:
        obj: the object to serialize.
        path: destination file path without the ``.pkl`` extension.
    """
    pkl_path = os.path.normpath(path + '.pkl')
    with open(pkl_path, 'wb') as stream:
        pickle.dump(obj, stream, protocol=pickle.HIGHEST_PROTOCOL)
# Display the images found at the given paths with matplotlib.
#
# Each entry of `imgs_path` is opened with PIL's Image.open and drawn with
# plt.imshow in the figure named "Image", with the axes turned on and the
# title 'image_<index>'. Interactive mode is switched on before the loop.
#
# NOTE(review): the indentation of this function was lost (and every line
# carries a trailing " | |" artifact), so it cannot be told from here whether
# plt.ioff() / plt.show() / plt.close() run once per image or once after the
# loop. Confirm against the original file before re-indenting.
def show_imgs(imgs_path): | |
plt.ion() | |
for i, img_path in enumerate(imgs_path): | |
img = Image.open(img_path) | |
plt.figure("Image") # name of the figure window | |
plt.imshow(img) | |
plt.axis('on') # use 'off' to hide the axes | |
plt.title('image_{}'.format(i)) # title of the figure | |
plt.ioff() | |
plt.show() | |
plt.close() | |