import argparse
import os
import pickle
import re
import sys
import time
from io import BytesIO
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from translate import Translator
# from loguru import logger
# from sentence_transformers import SentenceTransformer, util

from misc.dataset import TextDataset, TextEncoder
from misc.evaluation import recallTopK
from misc.model import img_embedding, joint_embedding
from misc.utils import collate_fn_cap_padded, load_obj, show_imgs
device = torch.device("cpu")   # inference runs on CPU
batch_size = 1
topK = 5                       # number of retrieved images shown in the UI
T2I = "Text 2 Image"
I2I = "Image 2 Image"
model_path = "data/best_model.pth.tar"
# model = SentenceTransformer("clip-ViT-B-32")
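# The checkpoint loaded in __main__ below is assumed (based on how it is read
# there) to be a dict with at least:
#   checkpoint["args_dict"]   - constructor arguments for joint_embedding
#   checkpoint["state_dict"]  - trained weights for the joint embedding model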
def download_url_img(url):
    """Download an image from `url`; return (success, PIL image) or (False, [])."""
    try:
        response = requests.get(url, timeout=3)
    except Exception as e:
        print(str(e))
        return False, []
    if response is not None and response.status_code == 200:
        input_image_data = response.content
        image = Image.open(BytesIO(input_image_data))
        return True, image
    return False, []
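# Example usage (hypothetical URL): a successful call returns (True, <PIL.Image.Image>);
# a timeout, connection error, or non-200 status returns (False, []).
# ok, img = download_url_img("http://images.cocodataset.org/train2017/000000000009.jpg")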
def search(mode, text):
    """Text-to-image search: return up to topK PIL images matching the query text."""
    # translator = Translator(from_lang="chinese", to_lang="english")
    # text = translator.translate(text)

    # Encode the query text and wrap it in a DataLoader so it goes through the
    # same padding collate function used at training time.
    dataset = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
    dataset_loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                                pin_memory=True, collate_fn=collate_fn_cap_padded)
    caps_enc = list()
    for i, (caps, length) in enumerate(dataset_loader, 0):
        input_caps = caps.to(device)
        with torch.no_grad():
            _, output_emb = join_emb(None, input_caps, length)
        caps_enc.append(output_emb.cpu().data.numpy())
    caps_stack = np.vstack(caps_enc)

    # Map each precomputed image embedding back to its public COCO train2017 URL.
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1])
                for img_path in imgs_path]
    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)

    # Cat image downloaded from https://www.flickr.com/photos/blacktigersdream/23119711630
    # cat_image = "./cat_example.jpg"
    # Dog example downloaded from https://upload.wikimedia.org/wikipedia/commons/1/18/Dog_Breeds.jpg
    # dog_image = "./dog_example.jpg"

    # Keep the first topK candidates that download successfully.
    res = []
    idx = 0
    for img_url in recall_imgs:
        if idx == topK:
            break
        b, img = download_url_img(img_url)
        if b:
            res.append(img)
            idx += 1
    return res
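# End-to-end flow of search(): the query text is encoded by TextEncoder, mapped
# into the joint text-image embedding space by join_emb, ranked against the
# precomputed COCO image embeddings with recallTopK, and the best-matching
# images are downloaded from images.cocodataset.org for display.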
if __name__ == "__main__":
    # print("Loading model from:", model_path)
    # Load the pretrained joint embedding checkpoint on CPU and freeze it for inference.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])

    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed COCO image embeddings and the corresponding image paths.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    print("prepare done!")
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([T2I]),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="Introduce the search text...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        title="HUST Graduation Project - Image-Text Retrieval System",
        description="Enter an image or text; related images will be shown:",
    )
    iface.launch(share=False)
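# NOTE (assumption, not part of the original app): newer Gradio releases deprecate
# and then remove the gr.inputs / gr.outputs namespaces and string themes such as
# "grass". A rough equivalent of the interface above on Gradio 3/4 would be:
#
# iface = gr.Interface(
#     fn=search,
#     inputs=[
#         gr.Radio([T2I], label="Mode"),
#         gr.Textbox(lines=1, label="Text query", placeholder="Introduce the search text..."),
#     ],
#     outputs=[gr.Image(type="pil", label=f"Match {i + 1}") for i in range(topK)],
#     title="HUST Graduation Project - Image-Text Retrieval System",
#     description="Enter an image or text; related images will be shown:",
# )
# iface.launch(share=False)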