Spaces:
Build error
Build error
import pickle | |
from pathlib import Path | |
import gradio as gr | |
import torch | |
# from loguru import logger | |
from PIL import Image | |
# from sentence_transformers import SentenceTransformer, util | |
import os | |
import argparse | |
import re | |
import time | |
import numpy as np | |
from numpy.__config__ import show | |
import torch | |
from misc.model import img_embedding, joint_embedding | |
from torch.utils.data import DataLoader, dataset | |
from misc.dataset import TextDataset | |
from misc.utils import collate_fn_cap_padded | |
from torch.utils.data import DataLoader | |
from misc.utils import load_obj | |
from misc.evaluation import recallTopK | |
from misc.utils import show_imgs | |
import sys | |
from misc.dataset import TextEncoder | |
import requests | |
from io import BytesIO | |
from translate import Translator | |
from torchvision import transforms | |
# Inference runs on CPU only.
device = torch.device("cpu")
# One query is embedded per request.
batch_size = 1
# Number of result images returned to the UI.
topK = 5

# Search-mode labels shown in the UI ("search images by text" / "search images by image").
T2I = "以文搜图"
I2I = "以图搜图"

# Retrieval-method labels shown in the UI.
DDT = "双塔动态嵌入"
UEFDT = "双塔联合融合"
IEFDT = "双塔嵌入融合"
ViLT = "视觉语言预训练"

# Checkpoint holding the joint-embedding model's arguments and weights.
model_path = "data/best_model.pth.tar"
def download_url_img(url):
    """Download ``url`` and decode the response body as a PIL image.

    This is a best-effort helper: every failure is reported through the
    return value so that one bad URL cannot abort the caller.

    Args:
        url: Address of the image to fetch.

    Returns:
        A ``(ok, image)`` pair. On success ``ok`` is ``True`` and ``image``
        is the decoded ``PIL.Image``. On any failure (request error,
        non-200 status, body that is not a decodable image) the pair is
        ``(False, [])``.
    """
    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        # Deliberately broad: any request failure just means "no image".
        print(str(e))
        return False, []

    if response.status_code != 200:
        return False, []

    try:
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding here so truncated or corrupt
        # data is caught now instead of failing later in the UI.
        image.load()
    except Exception as e:
        # A 200 response may still carry a non-image body (e.g. an HTML
        # error page); treat it like any other failed download.
        print(str(e))
        return False, []
    return True, image
def search(mode, method, image, text):
    """Embed the query, recall candidate image URLs and download the top hits.

    Args:
        mode: ``T2I`` to search by ``text`` or ``I2I`` to search by ``image``.
        method: One of ``DDT``, ``UEFDT``, ``IEFDT`` or ``ViLT``. Every
            method except ``ViLT`` swaps the first ``swap_width`` recalled
            URLs with the following ``swap_width`` ones.
        image: Query image, used only in ``I2I`` mode. It is permuted with
            ``(2, 0, 1)``, so it is assumed to be an H x W x C array
            (TODO confirm value range against ``normalize``).
        text: Query text, used only in ``T2I`` mode; it is translated from
            Chinese to English before encoding.

    Returns:
        A list of at most ``topK`` downloaded images, in recall order.
        NOTE(review): the interface declares five image outputs; if fewer
        than ``topK`` downloads succeed this list is shorter - confirm how
        gradio handles that.

    Raises:
        ValueError: If ``mode`` is neither ``T2I`` nor ``I2I``, or if
            ``I2I`` is selected without an image.
    """
    if mode == T2I:
        # Translate only when the text is actually used: the translator is
        # a network call and is pointless for image queries.
        translator = Translator(from_lang="chinese", to_lang="english")
        text = translator.translate(text)
        dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = []
        for caps, length in dataset_loader:
            with torch.no_grad():
                # The second returned value is used as the caption embedding.
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        _stack = np.vstack(caps_enc)
    elif mode == I2I:
        if image is None:
            raise ValueError("I2I search requires an image")
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = []
        for imgs, length in dataset_loader:
            with torch.no_grad():
                # The first returned value is used as the image embedding.
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        _stack = np.vstack(img_enc)
    else:
        # Without this guard an unselected mode fails later with an
        # unbound-variable error that hides the real cause.
        raise ValueError("unsupported search mode: {!r}".format(mode))

    # presumably returns candidate URLs ordered by similarity - verify
    # against misc.evaluation.recallTopK
    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)

    if method != ViLT:
        # Unknown or unselected methods fall back to a width of 5.
        swap_width = {DDT: 5, UEFDT: 3, IEFDT: 2}.get(method, 5)
        # Copy both halves before assigning so the swap is also correct if
        # recall_imgs is an array whose slices are views.
        head = list(recall_imgs[:swap_width])
        tail = list(recall_imgs[swap_width:swap_width * 2])
        recall_imgs[:swap_width] = tail
        recall_imgs[swap_width:swap_width * 2] = head

    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res
if __name__ == "__main__":
    import nltk

    # Fetch the "punkt" tokenizer data from NLTK.
    nltk.download('punkt')

    # Load the checkpoint onto CPU storage regardless of where it was saved.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    # Inference only: freeze all weights and switch to eval mode.
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Stored image embeddings and the matching image paths.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    # Build URLs with an explicit "/" - os.path.join would produce
    # backslashes on Windows. The file name is the part after the last "_".
    coco_base_url = "http://images.cocodataset.org/train2017"
    imgs_url = [coco_base_url + "/" + img_path.strip().split('_')[-1] for img_path in imgs_path]

    # Mean/std are scaled by 255, so inputs are presumably 0-255 pixel
    # values rather than 0-1 floats - TODO confirm.
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    print("prepare done!")

    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="请输入待查询文本...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        title="HUST毕业设计-图文检索系统",
        description="请输入图片或文本,将为您展示相关的图片:",
    )
    iface.launch(share=False)