import pickle from pathlib import Path import gradio as gr import torch # from loguru import logger from PIL import Image # from sentence_transformers import SentenceTransformer, util import os import argparse import re import time import numpy as np from numpy.__config__ import show import torch from misc.model import img_embedding, joint_embedding from torch.utils.data import DataLoader, dataset from misc.dataset import TextDataset from misc.utils import collate_fn_cap_padded from torch.utils.data import DataLoader from misc.utils import load_obj from misc.evaluation import recallTopK from misc.utils import show_imgs import sys from misc.dataset import TextEncoder import requests import cv2 from io import BytesIO from translate import Translator import cupy as cp device = torch.device("cuda") batch_size = 1 topK = 5 T2I = "Text 2 Image" I2I = "Image 2 Image" model_path = "data/best_model.pth.tar" # model = SentenceTransformer("clip-ViT-B-32") img_folder = Path("./photos/") # start def download_url_img(url): try: response = requests.get(url, timeout=3) except Exception as e: print(str(e)) return False, [] if response is not None and response.status_code == 200: input_image_data = response.content # np_arr = np.asarray(bytearray(input_image_data), np.uint8).reshape(1, -1) # parsed_image = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) image=Image.open(BytesIO(input_image_data)) return True, image return False, [] def search(mode, text): # translator = Translator(from_lang="chinese",to_lang="english") # text = translator.translate(text) dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0) dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, collate_fn=collate_fn_cap_padded) caps_enc = list() for _, (caps, length) in enumerate(dataset_loader, 0): input_caps = caps.to(device) with torch.no_grad(): _, caps_emb = join_emb(None, input_caps, length) caps_enc.append(caps_emb) caps_stack = cp.vstack(caps_enc) imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path] recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100) # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630 # cat_image = "./cat_example.jpg" # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg # dog_image = "./dog_example.jpg" res = [] idx = 0 for img_url in recall_imgs: if idx == topK: break b, img = download_url_img(img_url) if b: res.append(img) idx += 1 return res if __name__ == "__main__": # print("Loading model from:", model_path) checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) join_emb = joint_embedding(checkpoint['args_dict']) join_emb.load_state_dict(checkpoint["state_dict"]) for param in join_emb.parameters(): param.requires_grad = False join_emb.to(device) join_emb.eval() encoder = TextEncoder() imgs_emb_file_path = "./coco_img_emb" imgs_emb, imgs_path = load_obj(imgs_emb_file_path) imgs_emb = cp.asarray(imgs_emb) print("prepare done!") iface = gr.Interface( fn=search, inputs=[ gr.inputs.Radio([T2I]), gr.inputs.Textbox( lines=1, label="Text query", placeholder="Introduce the search text...", ), ], theme="grass", outputs=[ gr.outputs.Image(type="auto", label="1st Best match"), gr.outputs.Image(type="auto", label="2nd Best match"), gr.outputs.Image(type="auto", label="3rd Best match"), gr.outputs.Image(type="auto", label="4rd Best match"), gr.outputs.Image(type="auto", label="5rd Best match") ], title="HUST毕业设计-图文检索系统", description="请输入图片或文本,将为您展示相关的图片:", ) iface.launch(share=True)