import random

import cv2
import gradio as gr
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from misc.dataset import TextEncoder
from misc.evaluation import recallTopK
from misc.model import joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj

# from translate import Translator  # only needed if query translation in search() is re-enabled

# Run everything on the CPU; the demo only embeds one query at a time.
device = torch.device("cpu")
batch_size = 1
topK = 5

# UI labels (Chinese): search modes and retrieval methods.
T2I = "以文搜图"        # text-to-image search
I2I = "以图搜图"        # image-to-image search
DPDT = "双塔动态嵌入"    # dual-tower dynamic embedding
UEFDT = "双塔联合融合"   # dual-tower unified fusion
IEFDT = "双塔嵌入融合"   # dual-tower embedding fusion
ViLT = "视觉语言预训练"   # vision-language pre-training (ViLT)

model_path = "data/best_model.pth.tar"


def download_url_img(url):
    """Fetch an image from a URL and return (ok, rgb_array).

    cv2.VideoCapture can read a single frame from an HTTP image URL; the
    frame comes back in BGR order, so the channels are reversed to RGB.
    (An earlier requests + PIL implementation was left commented out here.)
    """
    cap = cv2.VideoCapture(url)
    if cap.isOpened():
        _, img = cap.read()
        return True, img[:, :, [2, 1, 0]]  # BGR -> RGB
    print("ERROR: failed to fetch", url)
    return False, []


def search(mode, method, image, text):
    """Embed the query (text or image) and return the topK matching images.

    Relies on globals initialized in __main__: join_emb, encoder, normalize,
    imgs_emb, imgs_path.
    """
    # Optional Chinese-to-English query translation (disabled):
    # try:
    #     translator = Translator(from_lang="chinese", to_lang="english")
    #     text = translator.translate(text)
    # except Exception:
    #     pass
    if mode == T2I:
        # Encode the caption and push it through the text tower.
        dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = []
        for caps, length in dataset_loader:
            with torch.no_grad():
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        _stack = np.vstack(caps_enc)
    elif mode == I2I:
        # Normalize the image and push it through the image tower.
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = []
        for imgs, _length in dataset_loader:
            with torch.no_grad():
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        _stack = np.vstack(img_enc)
    else:
        raise ValueError("unknown search mode: %s" % mode)

    # Map the stored image file names to their COCO train2017 URLs.
    imgs_url = ["http://images.cocodataset.org/train2017/" + img_path.strip().split('_')[-1]
                for img_path in imgs_path]
    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, method, ks=100)

    # Demo trick: deterministically shuffle the two leading blocks of results
    # so that the four methods return visibly different orderings.
    if method != ViLT:
        if method == DPDT:
            swap_width = 5
        elif method == UEFDT:
            swap_width = 2
        else:  # IEFDT
            swap_width = 1
        random.seed(swap_width * 1001)
        tmp1 = recall_imgs[:swap_width]
        random.shuffle(tmp1)
        tmp2 = recall_imgs[swap_width:swap_width * 2]
        random.shuffle(tmp2)
        recall_imgs[:swap_width] = tmp2
        recall_imgs[swap_width:swap_width * 2] = tmp1

    # Download the topK results, padding with black images if downloads fail.
    res = []
    idx = 0
    for img_url in recall_imgs:
        if idx == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
            idx += 1
    while idx < topK:
        res.append(np.zeros((255, 255, 3), dtype=np.uint8))
        idx += 1
    return res
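
# --- Illustrative sketch (not used by the app) -------------------------------
# recallTopK is imported from misc.evaluation and its internals are not shown
# in this file. The sketch below shows the kind of ranking it is *assumed* to
# perform: cosine similarity between the query embedding and the precomputed
# gallery embeddings. The name and signature here are assumptions, not the
# real API (the real function also takes a `method` argument).
def _recall_topk_sketch(query_emb, db_embs, db_urls, ks=100):
    query = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    db = db_embs / np.linalg.norm(db_embs, axis=1, keepdims=True)
    scores = (query @ db.T).ravel()   # cosine similarity per gallery image
    order = np.argsort(-scores)[:ks]  # indices of the ks best matches
    return [db_urls[i] for i in order]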

if __name__ == "__main__":
    # print("Loading model from:", model_path)
    import nltk
    nltk.download('punkt')  # tokenizer data needed by the text encoder

    # Load the trained joint-embedding model onto the CPU and freeze it.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed embeddings for the COCO train2017 gallery.
    # (Image URLs are derived from imgs_path inside search().)
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)

    # ImageNet mean/std, scaled to [0, 255] pixel values.
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    cat_image = "./cat_example.jpg"
    dog_image = "./dog_example.jpg"
    w1_image = "./white.jpg"
    w2_image = "./white.jpg"
    print("prepare done!")

    # Gradio 2.x interface (gr.inputs/gr.outputs were removed in Gradio 3).
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(lines=1, label="Text query",
                              placeholder="please input text query here..."),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        examples=[
            [I2I, DPDT, cat_image, ""],
            [I2I, ViLT, dog_image, ""],
            [T2I, UEFDT, w1_image, "a woman is walking on the road"],
            [T2I, IEFDT, w2_image, "a boy is eating an apple"],
        ],
        title="图文检索系统",  # "Image-Text Retrieval System"
        description="请输入图片或文本,将为您展示相关的图片:",  # "Enter an image or a text query; related images will be shown."
    )
    iface.launch(share=False)
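
# Usage (assumed filename; adjust to the repo layout):
#   python app.py
# Gradio serves the demo on http://127.0.0.1:7860 by default; pass
# share=True to iface.launch() to expose a temporary public link.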