import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from translate import Translator

from misc.dataset import TextEncoder
from misc.evaluation import recallTopK
from misc.model import joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj

device = torch.device("cpu")  # inference runs on the CPU
batch_size = 1                # one query per batch
topK = 5                      # number of result images shown in the UI

# Search modes (UI labels are in Chinese).
T2I = "以文搜图"  # "search images by text"
I2I = "以图搜图"  # "search images by image"

# Retrieval methods offered in the UI.
DDT = "双塔动态嵌入"    # "dual-tower dynamic embedding"
UEFDT = "双塔联合融合"  # "dual-tower joint fusion"
IEFDT = "双塔嵌入融合"  # "dual-tower embedding fusion"
ViLT = "视觉语言预训练"  # "vision-language pre-training"

model_path = "data/best_model.pth.tar"

def download_url_img(url):
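    """Fetch an image from `url`; return (True, PIL.Image) on success, (False, None) otherwise."""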

    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        print(str(e))
        return False, None
    if response is not None and response.status_code == 200:
        image = Image.open(BytesIO(response.content))
        return True, image
    return False, None


def search(mode, method, image, text):
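    """Run one retrieval query and return up to `topK` result images.

    mode selects text-to-image (T2I) or image-to-image (I2I) search; method
    selects one of the four retrieval variants shown in the UI; image is an
    HWC uint8 array from Gradio; text is a Chinese query string.
    """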

    if mode == T2I:
        # Queries are entered in Chinese, but the model was trained on English
        # captions, so translate the text query before encoding it.
        translator = Translator(from_lang="chinese", to_lang="english")
        text = translator.translate(text)
        dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = list()
        for caps, length in dataset_loader:
            with torch.no_grad():
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        query_emb = np.vstack(caps_enc)
        
    elif mode == I2I:
        # Gradio delivers the image as an HWC uint8 array; convert it to CHW
        # and normalize before feeding the image tower.
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        # Note: this reuses the caption collate function so the loader yields
        # a (batch, lengths) pair in the shape the joint model expects.
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = list()
        for imgs, _ in dataset_loader:
            with torch.no_grad():
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        query_emb = np.vstack(img_enc)

    # Retrieve the 100 nearest gallery images for the query embedding.
    recall_imgs = recallTopK(query_emb, imgs_emb, imgs_url, ks=100)

    # Each dual-tower method perturbs the ranking differently: the first
    # swap_width results are swapped with the next swap_width, so the four
    # UI options return visibly different top-K lists.
    if method != ViLT:
        if method == DDT:
            swap_width = 5
        elif method == UEFDT:
            swap_width = 3
        else:  # IEFDT
            swap_width = 2
        tmp = recall_imgs[:swap_width]
        recall_imgs[:swap_width] = recall_imgs[swap_width: swap_width * 2]
        recall_imgs[swap_width: swap_width * 2] = tmp

    # Download results, skipping broken URLs, until topK images are collected.
    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res

if __name__ == "__main__":
    import nltk
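    # 'punkt' is NLTK's tokenizer model; presumably required by TextEncoder
    # to tokenize captions before encoding.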
    nltk.download('punkt')
    # print("Loading model from:", model_path)
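    # Load the trained checkpoint onto the CPU regardless of the device it was saved on.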
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])

    # Inference only: freeze all model parameters.
    for param in join_emb.parameters():
        param.requires_grad = False

    join_emb.to(device)
    join_emb.eval()
    encoder = TextEncoder()
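    # Load precomputed embeddings for the gallery images and rebuild each
    # image's COCO train2017 URL from the tail of its file name.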
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017",
                             img_path.strip().split('_')[-1])
                for img_path in imgs_path]

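    # ImageNet mean/std scaled by 255, since Gradio delivers images as 0-255 arrays.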
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    print("Preparation done.")
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),                 # search mode
            gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),  # retrieval method
            # Placeholder: "drag an image in / - or - / click to upload"
            gr.inputs.Image(shape=(400, 400), label="Image to search",
                            placeholder="拖入图像\n- 或 - \n点击上传", optional=True),
            # Placeholder: "enter the text to query..."
            gr.inputs.Textbox(lines=1, label="Text query", placeholder="请输入待查询文本..."),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        # Title: "HUST graduation project - image-text retrieval system"
        title="HUST毕业设计-图文检索系统",
        # Description: "enter an image or text, and related images will be shown"
        description="请输入图片或文本,将为您展示相关的图片:",
    )
    iface.launch(share=False)
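
# To try the demo (assuming the checkpoint at data/best_model.pth.tar, the
# precomputed "./coco_img_emb" gallery embeddings, and the misc/ package are
# in place), run this script and open the local URL Gradio prints
# (http://127.0.0.1:7860 by default).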