File size: 4,541 Bytes
950c874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ******************
Copyright (c) 2018 [Thomson Licensing]
All Rights Reserved
This program contains proprietary information which is a trade secret/business \
secret of [Thomson Licensing] and is protected, even if unpublished, under \
applicable Copyright laws (including French droit d'auteur) and/or may be \
subject to one or more patent(s).
Recipient is to retain this program in confidence and is not permitted to use \
or make copies thereof other than as permitted in a written agreement with \
[Thomson Licensing] unless otherwise expressly allowed by applicable laws or \
by [Thomson Licensing] under express agreement.
Thomson Licensing is a company of the group TECHNICOLOR
*******************************************************************************
This script runs interactive text-to-image search with the joint embedding model of:
    Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2018, April).
    Finding beans in burgers: Deep semantic-visual embedding with localization.
    In Proceedings of CVPR (pp. 3984-3993)

Author: Martin Engilberge
"""

import argparse
import re
import time

import numpy as np
from numpy.__config__ import show
import torch


from misc.model import img_embedding, joint_embedding
from torch.utils.data import DataLoader, dataset

from misc.dataset import TextDataset
from misc.utils import collate_fn_cap_padded
from torch.utils.data import DataLoader
from misc.utils import load_obj 
from misc.evaluation import recallTopK 

from misc.utils import show_imgs
import sys 
from misc.dataset import TextEncoder 

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing at the first ``.to(device)`` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default location of the pre-computed candidate-image embeddings read by
# ``load_obj``; can be overridden on the command line with ``--imgs_emb``.
DEFAULT_IMGS_EMB_PATH = "/home/atticus/proj/matching/DSVE/imgs_embed/v20210915_01_9408/allImg"


def _build_arg_parser():
    """Return the command-line parser for the interactive search script."""
    parser = argparse.ArgumentParser(description='Extract embedding representation for images')
    parser.add_argument("-p", '--path', dest="model_path", required=True,
                        help='Path to the weights of the model to evaluate')
    # NOTE(review): currently unused by the script; kept so existing command lines still parse.
    parser.add_argument("-d", '--data', dest="data_path",
                        help='path to the file containing the sentence to embed')
    parser.add_argument("-bs", "--batch_size", help="The size of the batches", type=int, default=1)
    parser.add_argument("-e", "--imgs_emb", dest="imgs_emb_path", default=DEFAULT_IMGS_EMB_PATH,
                        help="Path to the saved candidate-image embeddings (loaded with load_obj)")
    parser.add_argument("-k", "--top_k", type=int, default=5,
                        help="Number of images to retrieve for each description")
    return parser


def _load_model(model_path):
    """Load a ``joint_embedding`` checkpoint, freeze it and put it on ``device`` in eval mode.

    Args:
        model_path: path to a checkpoint holding ``args_dict`` and ``state_dict`` entries.

    Returns:
        The ready-to-use ``joint_embedding`` module.
    """
    print("Loading model from:", model_path)
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])

    for param in join_emb.parameters():
        param.requires_grad = False

    join_emb.to(device)
    join_emb.eval()
    return join_emb


def _embed_caption(join_emb, encoder, cap_str, batch_size):
    """Embed one caption into the joint space.

    Args:
        join_emb: model returned by ``_load_model``.
        encoder: ``TextEncoder`` used to turn the caption into word vectors.
        cap_str: the (already stripped) caption text.
        batch_size: DataLoader batch size.

    Returns:
        A numpy array stacking the caption embeddings, or ``None`` when the
        encoder produced nothing for ``cap_str``.
    """
    encoded = encoder.encode(cap_str)
    # An empty encoding cannot be fed to the text model; presumably this happens
    # when no word of the caption is known to the encoder.
    if len(encoded) == 0:
        return None

    # Leading axis of size 1: the DataLoader then yields the caption as a single item.
    caption_tensor = torch.Tensor(encoded).unsqueeze(dim=0)
    # num_workers=0: spawning a worker process for a single caption per query is pure overhead.
    caption_loader = DataLoader(caption_tensor, batch_size=batch_size, num_workers=0,
                                pin_memory=device.type == "cuda",
                                collate_fn=collate_fn_cap_padded)

    caps_enc = []
    with torch.no_grad():
        for caps, length in caption_loader:
            _, output_emb = join_emb(None, caps.to(device), length)
            caps_enc.append(output_emb.cpu().numpy())
    return np.vstack(caps_enc)


def main():
    """Interactive text-to-image search: read descriptions from stdin and show the top-k images."""
    args = _build_arg_parser().parse_args()

    join_emb = _load_model(args.model_path)
    encoder = TextEncoder()
    print("Loading model done")

    # Load the candidate-image embeddings once, outside the query loop.
    # The source notes shape (40775, 2400) for the default file.
    # NOTE(review): this is a saved pkl file; only point it at trusted files.
    imgs_emb, imgs_path = load_obj(args.imgs_emb_path)

    prompt = "Please input your description of the image that you wanna search >>>"
    print(prompt)
    for line in sys.stdin:
        cap_str = line.strip()
        if not cap_str:
            print(prompt)
            continue

        print("text is embedding ...")
        caps_stack = _embed_caption(join_emb, encoder, cap_str, args.batch_size)
        if caps_stack is None:
            print("Could not encode this description, please try another one.")
            print(prompt)
            continue

        print("recall from resources ...")
        # Rank the candidate images by similarity to the caption and keep the top-k.
        recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_path, ks=args.top_k)
        show_imgs(imgs_path=recall_imgs)

        print("======== current epoch done ========")
        print(prompt)


if __name__ == '__main__':
    main()