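"""Gradio demo for text-to-image retrieval.

A text query is embedded with a trained joint text-image embedding model
(misc.model.joint_embedding) and compared against pre-computed COCO train2017
image embeddings; the top-K matching images are downloaded from
images.cocodataset.org and shown in a Gradio interface.

Expected local files (paths are taken from the constants below):
  - data/best_model.pth.tar : trained joint embedding checkpoint
  - ./coco_img_emb          : pre-computed image embeddings and image paths
"""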
import os
from io import BytesIO
from pathlib import Path

import cupy as cp
import gradio as gr
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader

from misc.dataset import TextEncoder
from misc.evaluation import recallTopK
from misc.model import joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj

# Inference settings: run on the GPU, encode one query at a time, return the 5 best matches.
device = torch.device("cuda")
batch_size = 1
topK = 5

# Search modes offered in the UI (only text-to-image is currently exposed).
T2I = "Text 2 Image"
I2I = "Image 2 Image"

# Trained joint embedding checkpoint.
model_path = "data/best_model.pth.tar"

# Local photo folder (currently unused; images are fetched from the COCO server instead).
img_folder = Path("./photos/")

def download_url_img(url):
    """Download an image and return (True, PIL.Image); return (False, []) on any failure."""
    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        print(str(e))
        return False, []
    if response is not None and response.status_code == 200:
        image = Image.open(BytesIO(response.content))
        return True, image
    return False, []


def search(mode, text):
    """Encode the text query, rank the pre-computed image embeddings, and return the top-K images."""
    # Encode the query and wrap it in a DataLoader so it goes through the same
    # padded-caption collate function used at training time.
    dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
    dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                pin_memory=True, collate_fn=collate_fn_cap_padded)

    # Embed the caption with the text branch of the joint embedding model.
    caps_enc = []
    for caps, length in dataset_loader:
        input_caps = caps.to(device)
        with torch.no_grad():
            _, caps_emb = join_emb(None, input_caps, length)
        caps_enc.append(caps_emb)
    caps_stack = cp.vstack(caps_enc)

    # Map each stored image path to its COCO train2017 URL.
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1])
                for img_path in imgs_path]

    # Rank images by similarity; keep a generous candidate list since some URLs may fail to download.
    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)

    # Download candidates until topK images have been fetched successfully.
    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res

if __name__ == "__main__":
    # Load the trained joint embedding checkpoint onto the CPU first, then move it to the GPU.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)

    join_emb = joint_embedding(checkpoint["args_dict"])
    join_emb.load_state_dict(checkpoint["state_dict"])

    # Freeze the model; it is only used for inference.
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Pre-computed COCO image embeddings and their corresponding image paths.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    imgs_emb = cp.asarray(imgs_emb)
    print("Preparation done!")

    # Gradio front end: a text query in, the five best-matching images out.
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([T2I]),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="Enter the search text...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        title="HUST Graduation Project - Image-Text Retrieval System",
        description="Enter an image or text, and related images will be shown:",
    )
    iface.launch(share=True)
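
# Example usage (the file name below is hypothetical; substitute this script's actual name):
#   python demo_app.py
# Then open the local or public URL printed by iface.launch(share=True) and type a
# caption such as "a dog playing on the beach" to retrieve the matching COCO images.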