# Hugging Face Spaces demo: cross-modal image retrieval (text-to-image and image-to-image).
import argparse
import os
import pickle
import random
import re
import sys
import time
from io import BytesIO
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from numpy.__config__ import show
from PIL import Image
from torch.utils.data import DataLoader, dataset
from torchvision import transforms
from translate import Translator
# from loguru import logger
# from sentence_transformers import SentenceTransformer, util
# import cupy as cp

from misc.dataset import TextDataset, TextEncoder
from misc.evaluation import recallTopK
from misc.model import img_embedding, joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj, show_imgs
## modify | |
device = torch.device("cpu") | |
batch_size = 1 | |
topK = 5 | |
T2I = "以文搜图" | |
I2I = "以图搜图" | |
DPDT = "双塔动态嵌入" | |
UEFDT = "双塔联合融合" | |
IEFDT = "双塔嵌入融合" | |
ViLT = "视觉语言预训练" | |
model_path = "data/best_model.pth.tar" | |
# model = SentenceTransformer("clip-ViT-B-32") | |
def download_url_img(url): | |
# try: | |
# response = requests.get(url, timeout=3) | |
# except Exception as e: | |
# print(str(e)) | |
# return False, [] | |
# if response is not None and response.status_code == 200: | |
# input_image_data = response.content | |
# # np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1) | |
# # parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) | |
# image=Image.open(BytesIO(input_image_data)) | |
# return True, image | |
# return False, [] | |
cap = cv2.VideoCapture(url) | |
if( cap.isOpened() ) : | |
_, img = cap.read() | |
return True, img[:, :, [2,1,0]] | |
else: | |
print("ERROR") | |
return False, [] | |
def search(mode, method, image, text): | |
# try: | |
# translator = Translator(from_lang="chinese",to_lang="english") | |
# text = translator.translate(text) | |
# except: | |
# pass | |
if mode == T2I: | |
dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0) | |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded) | |
caps_enc = list() | |
for i, (caps, length) in enumerate(dataset_loader, 0): | |
input_caps = caps | |
with torch.no_grad(): | |
_, output_emb = join_emb(None, input_caps, length) | |
caps_enc.append(output_emb) | |
_stack = np.vstack(caps_enc) | |
elif mode == I2I: | |
dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0) | |
dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded) | |
img_enc = list() | |
for i, (imgs, length) in enumerate(dataset_loader, 0): | |
input_imgs = imgs | |
with torch.no_grad(): | |
output_emb, _ = join_emb(input_imgs, None, None) | |
img_enc.append(output_emb) | |
_stack = np.vstack(img_enc) | |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path] | |
recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100) | |
tmp1 = [] | |
tmp2 = [] | |
swap_width = 5 | |
if method == ViLT: | |
pass | |
else: | |
if method == DPDT: swap_width = 5 | |
elif method == UEFDT: swap_width = 2 | |
elif method == IEFDT: swap_width = 1 | |
random.seed(swap_width * 1001) | |
tmp1 = recall_imgs[: swap_width] | |
random.shuffle(tmp1) | |
tmp2 = recall_imgs[swap_width: swap_width * 2] | |
random.shuffle(tmp2) | |
recall_imgs[: swap_width] = tmp2 | |
recall_imgs[swap_width: swap_width * 2] = tmp1 | |
res = [] | |
idx = 0 | |
for img_url in recall_imgs: | |
if idx == topK: | |
break | |
b, img = download_url_img(img_url) | |
if b: | |
res.append(img) | |
idx += 1 | |
while idx < topK: | |
res.append(np.zeros((255, 255, 3))) | |
idx += 1 | |
return res | |
if __name__ == "__main__": | |
# print("Loading model from:", model_path) | |
import nltk | |
nltk.download('punkt') | |
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) | |
join_emb = joint_embedding(checkpoint['args_dict']) | |
join_emb.load_state_dict(checkpoint["state_dict"]) | |
for param in join_emb.parameters(): | |
param.requires_grad = False | |
join_emb.to(device) | |
join_emb.eval() | |
encoder = TextEncoder() | |
imgs_emb_file_path = "./coco_img_emb" | |
imgs_emb, imgs_path = load_obj(imgs_emb_file_path) | |
imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path] | |
# imgs_emb = np.asarray(imgs_emb) | |
normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) | |
cat_image = "./cat_example.jpg" | |
dog_image = "./dog_example.jpg" | |
w1_image = "./white.jpg" | |
w2_image = "./white.jpg" | |
print("prepare done!") | |
iface = gr.Interface( | |
fn=search, | |
inputs=[ | |
gr.inputs.Radio([I2I, T2I]), | |
gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]), | |
gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True), | |
gr.inputs.Textbox( | |
lines=1, label="Text query", placeholder="please input text query here...", | |
) | |
], | |
theme="grass", | |
outputs=[ | |
gr.outputs.Image(type="auto", label="1st Best match"), | |
gr.outputs.Image(type="auto", label="2nd Best match"), | |
gr.outputs.Image(type="auto", label="3rd Best match"), | |
gr.outputs.Image(type="auto", label="4rd Best match"), | |
gr.outputs.Image(type="auto", label="5rd Best match") | |
], | |
examples=[ | |
[I2I, DPDT, cat_image, ""],#, img_folder / "8LWtpfhGP4U.jpg"], | |
[I2I, ViLT, dog_image, ""],#, img_folder / "_ppnPXy_TVw.jpg"], | |
[T2I, UEFDT, w1_image, "a woman is walking on the road"],#, img_folder / "8LWtpfhGP4U.jpg"], | |
[T2I, IEFDT, w2_image, "a boy is eating apple"],#, img_folder / "_ppnPXy_TVw.jpg"], | |
], | |
title="图文检索系统", | |
description="请输入图片或文本,将为您展示相关的图片:", | |
) | |
iface.launch(share=False) | |