"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This script permits one to reproduce training and experiments of:
    Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
    Finding beans in burgers: Deep semantic-visual embedding with localization.
    In Proceedings of CVPR (pp. 3984-3993)

Author: Martin Engilberge

Interactive text-to-image search: read one description per line from stdin,
embed it with the text branch of a trained ``joint_embedding`` model, rank a
precomputed collection of image embeddings with ``recallTopK`` and display the
best matches with ``show_imgs``.
"""
import argparse
import re
import sys
import time

import numpy as np
from numpy.__config__ import show
import torch
from torch.utils.data import DataLoader, dataset

from misc.dataset import TextDataset, TextEncoder
from misc.evaluation import recallTopK
from misc.model import img_embedding, joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj, show_imgs

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default location of the pickled (image embeddings, image paths) pair that
# load_obj() reads. Per the author's note, the embeddings in this file are
# shaped (40775, 2400).
DEFAULT_IMGS_EMB_PATH = "/home/atticus/proj/matching/DSVE/imgs_embed/v20210915_01_9408/allImg"

PROMPT = "Please input your description of the image that you wanna search >>>"


def _parse_args():
    """Build the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Extract embedding representation for images')
    parser.add_argument("-p", '--path', dest="model_path", required=True,
                        help='Path to the weights of the model to evaluate')
    parser.add_argument("-d", '--data', dest="data_path",
                        help='path to the file containing the sentence to embed '
                             '(currently unused; kept for CLI compatibility)')
    parser.add_argument("-bs", "--batch_size", help="The size of the batches", type=int, default=1)
    parser.add_argument("-e", "--imgs_emb", dest="imgs_emb_path", default=DEFAULT_IMGS_EMB_PATH,
                        help="Path (as accepted by load_obj) of the saved image embeddings")
    parser.add_argument("-k", "--topk", type=int, default=5,
                        help="Number of images to retrieve for each description")
    parser.add_argument("--timing", action="store_true",
                        help="Print per-stage timings after each query")
    return parser.parse_args()


def _load_model(model_path):
    """Load a ``joint_embedding`` checkpoint, freeze it and put it in eval mode on ``device``.

    Args:
        model_path: checkpoint file containing ``'args_dict'`` and ``'state_dict'``.

    Returns:
        The frozen model, already moved to ``device``.
    """
    print("Loading model from:", model_path)
    # map_location returning the storage unchanged loads every tensor on the
    # CPU first, so GPU-saved checkpoints load even without CUDA.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if
    # checkpoint['args_dict'] holds arbitrary Python objects this call may need
    # weights_only=False (only ever for trusted files).
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()
    return join_emb


def _embed_caption(join_emb, encoder, caption, batch_size):
    """Embed one caption with the text branch of ``join_emb``.

    Returns:
        A 2-D numpy array of caption embeddings (one row per caption), or
        ``None`` when ``encoder.encode`` yields nothing for ``caption``.
    """
    encoded = encoder.encode(caption)
    # presumably empty when no word of the caption is known to the encoder — TODO confirm
    if len(encoded) == 0:
        return None
    # Add a leading dimension so the DataLoader sees a dataset of one caption.
    caps_tensor = torch.Tensor(encoded).unsqueeze(dim=0)
    loader = DataLoader(caps_tensor,
                        batch_size=batch_size,
                        # A worker process per single-caption query only adds startup cost.
                        num_workers=0,
                        pin_memory=(device.type == "cuda"),
                        collate_fn=collate_fn_cap_padded)
    caps_enc = []
    for caps, length in loader:
        input_caps = caps.to(device)
        with torch.no_grad():
            # The image input is None, presumably so only the text embedding
            # is computed — see joint_embedding.forward in misc.model.
            _, output_emb = join_emb(None, input_caps, length)
        caps_enc.append(output_emb.cpu().numpy())
    return np.vstack(caps_enc)


def main():
    """Run the interactive search loop: one description per stdin line until EOF."""
    args = _parse_args()
    join_emb = _load_model(args.model_path)
    encoder = TextEncoder()
    print("Loading model done")

    # The candidate image embeddings do not change between queries, so load
    # them once instead of on every iteration.
    imgs_emb, imgs_path = load_obj(args.imgs_emb_path)

    # Interactive mode.
    print(PROMPT)
    for line in sys.stdin:
        t0 = time.time()
        cap_str = line.strip()
        if not cap_str:
            print(PROMPT)
            continue

        t1 = time.time()
        print("text is embedding ...")
        caps_stack = _embed_caption(join_emb, encoder, cap_str, args.batch_size)
        if caps_stack is None:
            print("No usable words in the description, please try again.")
            print(PROMPT)
            continue

        t2 = time.time()
        print("recall from resources ...")
        # Rank the images against the caption embedding and keep the top k.
        recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_path, ks=args.topk)
        t3 = time.time()

        show_imgs(imgs_path=recall_imgs)
        if args.timing:
            print("input stage time: {} \n text embedding stage time: {} \n recall stage time: {}".format(
                t1 - t0, t2 - t1, t3 - t2))
        print("======== current epoch done ========")
        print(PROMPT)


if __name__ == '__main__':
    main()