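"""Gradio demo for text-to-image retrieval: a text query is embedded with a
joint text-image model and matched against precomputed MS-COCO image embeddings."""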
import argparse
import os
import pickle
import re
import sys
import time
from io import BytesIO
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from translate import Translator
# from loguru import logger
# from sentence_transformers import SentenceTransformer, util

from misc.dataset import TextDataset, TextEncoder
from misc.evaluation import recallTopK
from misc.model import img_embedding, joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj, show_imgs
device = torch.device("cpu")
batch_size = 1
topK = 5
T2I = "Text 2 Image"
I2I = "Image 2 Image"
model_path = "data/best_model.pth.tar"
# model = SentenceTransformer("clip-ViT-B-32")
def download_url_img(url):
    """Download an image from a URL; return (success, PIL image) or (False, [])."""
    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        print(str(e))
        return False, []
    if response is not None and response.status_code == 200:
        input_image_data = response.content
        image = Image.open(BytesIO(input_image_data))
        return True, image
    return False, []
def search(mode, text):
    """Encode the text query with the joint model and return the top-K matching COCO images."""
    # translator = Translator(from_lang="chinese", to_lang="english")
    # text = translator.translate(text)
    dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
    dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                pin_memory=True, collate_fn=collate_fn_cap_padded)
    # Embed the query caption with the text branch of the joint embedding model.
    caps_enc = list()
    for i, (caps, length) in enumerate(dataset_loader, 0):
        input_caps = caps.to(device)
        with torch.no_grad():
            _, output_emb = join_emb(None, input_caps, length)
        caps_enc.append(output_emb.cpu().data.numpy())
    caps_stack = np.vstack(caps_enc)
    # Map the stored image file names to their COCO train2017 URLs.
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1])
                for img_path in imgs_path]
    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)
    # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
    # cat_image = "./cat_example.jpg"
    # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
    # dog_image = "./dog_example.jpg"
    # Keep the first topK candidates that can actually be downloaded.
    res = []
    idx = 0
    for img_url in recall_imgs:
        if idx == topK:
            break
        b, img = download_url_img(img_url)
        if b:
            res.append(img)
            idx += 1
    return res
if __name__ == "__main__":
    # print("Loading model from:", model_path)
    # Load the trained joint embedding checkpoint on CPU and freeze its weights.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    for param in join_emb.parameters():
        param.requires_grad = False

    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed COCO image embeddings and the corresponding image paths.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    print("prepare done!")
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([T2I]),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="Enter the search text...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        title="HUST Graduation Project - Image-Text Retrieval System",
        description="Enter an image or a piece of text, and related images will be shown:",
    )

    iface.launch(share=False)