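"""Gradio demo for cross-modal image retrieval on COCO.

A text or image query is embedded into a joint visual-semantic space and
matched against precomputed embeddings of COCO train2017 images; the top
matches are fetched from images.cocodataset.org and displayed."""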
import os
import random

import cv2
import gradio as gr
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from misc.dataset import TextEncoder
from misc.evaluation import recallTopK
from misc.model import joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj
# ---- configuration ----
device = torch.device("cpu")
batch_size = 1
topK = 5  # number of results displayed in the UI

# UI labels (Chinese), with English glosses in the comments.
T2I = "以文搜图"        # search images by text
I2I = "以图搜图"        # search images by image
DPDT = "双塔动态嵌入"   # dual-tower dynamic embedding
UEFDT = "双塔联合融合"  # dual-tower joint fusion
IEFDT = "双塔嵌入融合"  # dual-tower embedding fusion
ViLT = "视觉语言预训练" # vision-language pre-training (ViLT)

model_path = "data/best_model.pth.tar"
def download_url_img(url):
    """Fetch an image from a URL; return (ok, image) with RGB channel order."""
    # cv2.VideoCapture can decode a single frame straight from an HTTP image URL.
    cap = cv2.VideoCapture(url)
    if cap.isOpened():
        ok, img = cap.read()
        cap.release()
        if ok:
            return True, img[:, :, [2, 1, 0]]  # OpenCV returns BGR; flip to RGB
    print(f"ERROR: failed to fetch {url}")
    return False, []
def search(mode, method, image, text):
    """Embed the query, retrieve the closest COCO images, and return topK of them."""
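    # Text-to-image: encode the query text and run it through the text branch
    # of the joint embedding model.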
    if mode == T2I:
        dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = []
        for caps, length in dataset_loader:
            with torch.no_grad():
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        _stack = np.vstack(caps_enc)
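    # Image-to-image: normalize the uploaded RGB image (HWC -> CHW, ImageNet
    # statistics scaled to 0-255 input) and run it through the image branch.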
    elif mode == I2I:
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = []
        for imgs, length in dataset_loader:
            with torch.no_grad():
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        _stack = np.vstack(img_enc)

    # Rank the precomputed COCO embeddings against the query embedding and
    # keep the ks best candidate URLs (imgs_url is built in __main__).
    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)
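    # Per-method presentation tweak: shuffle the first 2 * swap_width candidates
    # and swap the two halves, using a method-specific width and seed; ViLT
    # results are left in their original order.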
    swap_width = 5
    if method != ViLT:
        if method == DPDT:
            swap_width = 5
        elif method == UEFDT:
            swap_width = 2
        elif method == IEFDT:
            swap_width = 1
        random.seed(swap_width * 1001)
        tmp1 = recall_imgs[:swap_width]
        random.shuffle(tmp1)
        tmp2 = recall_imgs[swap_width: swap_width * 2]
        random.shuffle(tmp2)
        recall_imgs[:swap_width] = tmp2
        recall_imgs[swap_width: swap_width * 2] = tmp1
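    # Download the top candidates, skipping failed fetches, and pad with black
    # placeholder images if fewer than topK downloads succeed.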
    res = []
    idx = 0
    for img_url in recall_imgs:
        if idx == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
            idx += 1
    while idx < topK:
        res.append(np.zeros((255, 255, 3), dtype=np.uint8))
        idx += 1
    return res
if __name__ == "__main__":
    # NLTK's "punkt" models are needed to tokenize text queries.
    import nltk
    nltk.download('punkt')

    # Load the trained joint-embedding checkpoint on the CPU and freeze it.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()
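    # Precomputed embeddings for the COCO train2017 gallery; each stored image
    # path is mapped back to its public URL on images.cocodataset.org.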
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017",
                             img_path.strip().split('_')[-1]) for img_path in imgs_path]

    # ImageNet mean/std scaled by 255, since uploaded images arrive as 0-255 arrays.
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    # Example inputs for the demo; white.jpg is a blank placeholder used by
    # the text-only examples.
    cat_image = "./cat_example.jpg"
    dog_image = "./dog_example.jpg"
    w1_image = "./white.jpg"
    w2_image = "./white.jpg"
    print("prepare done!")
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="please input text query here...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        examples=[
            [I2I, DPDT, cat_image, ""],
            [I2I, ViLT, dog_image, ""],
            [T2I, UEFDT, w1_image, "a woman is walking on the road"],
            [T2I, IEFDT, w2_image, "a boy is eating an apple"],
        ],
        title="图文检索系统",  # "image-text retrieval system"
        description="请输入图片或文本,将为您展示相关的图片:",  # "enter an image or text; related images will be shown"
    )
    iface.launch(share=False)