import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from translate import Translator

from misc.dataset import TextEncoder
from misc.evaluation import recallTopK
from misc.model import joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj

device = torch.device("cpu")
batch_size = 1
topK = 5
T2I = "Text 2 Image"
I2I = "Image 2 Image"
model_path = "data/best_model.pth.tar"

# Example images for manual testing:
# Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
# cat_image = "./cat_example.jpg"
# Dog image downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
# dog_image = "./dog_example.jpg"


def download_url_img(url):
    """Fetch an image from `url`; return (True, PIL.Image) on success, (False, []) otherwise."""
    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        print(str(e))
        return False, []
    if response is not None and response.status_code == 200:
        image = Image.open(BytesIO(response.content))
        return True, image
    return False, []


def search(mode, image, text):
    """Embed the query (text or image) and return the top-K retrieved COCO images."""
    if mode == T2I:
        # Queries are entered in Chinese; translate to English before encoding,
        # since the model was trained on English captions.
        translator = Translator(from_lang="chinese", to_lang="english")
        text = translator.translate(text)
        dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = []
        for caps, length in dataset_loader:
            with torch.no_grad():
                # The joint model returns (image_emb, caption_emb); keep the caption side.
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        _stack = np.vstack(caps_enc)
    elif mode == I2I:
        # Gradio hands us an HWC uint8 array in [0, 255]; `normalize` is scaled to match.
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                    pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = []
        for imgs, length in dataset_loader:
            with torch.no_grad():
                # Keep the image side of the joint embedding.
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        _stack = np.vstack(img_enc)

    # Retrieve 100 candidates, then keep the first topK whose URLs actually resolve.
    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)
    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res


if __name__ == "__main__":
    import nltk

    nltk.download('punkt')  # tokenizer data required by TextEncoder

    print("Loading model from:", model_path)
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed embeddings of the COCO train2017 images, plus their file paths.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    # Map each local file name to its images.cocodataset.org URL
    # (plain concatenation, since os.path.join is not URL-safe on Windows).
    imgs_url = ["http://images.cocodataset.org/train2017/" + img_path.strip().split('_')[-1]
                for img_path in imgs_path]

    # ImageNet normalization, scaled by 255 because inputs arrive as 0-255 arrays.
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                     std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    print("prepare done!")

    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(lines=1, label="Text query",
                              placeholder="Introduce the search text..."),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        title="HUST Graduation Project: Image-Text Retrieval System",
        description="Enter an image or some text, and related images will be shown:",
    )
    iface.launch(share=False)
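
# Usage note (a sketch; the script's file name is an assumption, not given in the source):
#
#   $ python app.py    # downloads NLTK punkt, loads the checkpoint,
#                      # then serves the demo at http://127.0.0.1:7860 (Gradio's default port)
#
# Prerequisites implied above: data/best_model.pth.tar (the trained joint-embedding
# checkpoint) and ./coco_img_emb (precomputed embeddings of the COCO train2017 images).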