Spaces:
Build error
Build error
File size: 4,206 Bytes
950c874 2836e50 950c874 953580c 950c874 953580c 950c874 953580c f763a16 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 950c874 953580c 2836e50 953580c 2836e50 953580c 950c874 953580c 950c874 953580c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import pickle
from pathlib import Path
import gradio as gr
import torch
# from loguru import logger
from PIL import Image
# from sentence_transformers import SentenceTransformer, util
import os
import argparse
import re
import time
import numpy as np
from numpy.__config__ import show
import torch
from misc.model import img_embedding, joint_embedding
from torch.utils.data import DataLoader, dataset
from misc.dataset import TextDataset
from misc.utils import collate_fn_cap_padded
from torch.utils.data import DataLoader
from misc.utils import load_obj
from misc.evaluation import recallTopK
from misc.utils import show_imgs
import sys
from misc.dataset import TextEncoder
import requests
import cv2
from io import BytesIO
from translate import Translator
import cupy as cp
# Inference device. The joint-embedding model and each encoded query are
# moved here; CUDA is assumed (the search path also relies on CuPy).
device = torch.device("cuda")

# Retrieval settings.
batch_size = 1  # captions per DataLoader batch
topK = 5  # number of images returned for one query

# Display names of the search modes (only T2I is wired into the UI).
T2I, I2I = "Text 2 Image", "Image 2 Image"

# On-disk resources.
model_path = "data/best_model.pth.tar"
img_folder = Path("./photos/")
def download_url_img(url, timeout=3):
    """Fetch an image over HTTP and decode it with PIL.

    Best-effort: failures are printed and signalled through the returned
    flag instead of being raised, so one bad URL does not abort a search.

    Args:
        url: Image URL to fetch.
        timeout: Seconds passed to ``requests.get`` (defaults to 3, the
            previously hard-coded value).

    Returns:
        ``(True, image)`` with a fully decoded ``PIL.Image.Image`` on
        success, otherwise ``(False, [])``. The empty-list placeholder is
        kept for backward compatibility; callers should only use the second
        element when the flag is True.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except Exception as e:  # deliberate best-effort: any network/URL error
        print(str(e))
        return False, []

    # requests.get never returns None, so only the status code matters.
    if response.status_code != 200:
        return False, []

    try:
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding now so non-image or truncated
        # payloads are rejected here instead of crashing the caller later.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and "image file is truncated" are OSErrors.
        print(str(e))
        return False, []
    return True, image
def search(mode, text):
    """Return up to ``topK`` COCO images that best match a text query.

    This is the Gradio callback for the text-to-image search.

    Args:
        mode: Value of the mode radio button. Currently ignored; only the
            text-to-image path is implemented.
        text: Free-text query, encoded with the module-level ``encoder``.

    Returns:
        A list of at most ``topK`` PIL images, in ranked order. Fewer are
        returned if downloads of candidate images fail.

    NOTE(review): relies on the globals ``encoder``, ``join_emb``,
    ``imgs_emb`` and ``imgs_path``, which are only created under
    ``if __name__ == "__main__"``; calling this after a plain import
    raises NameError.
    """
    # Optional Chinese -> English query translation is currently disabled
    # (see the unused ``Translator`` import).

    # Wrap the encoded caption as a one-row tensor; iterating it yields the
    # single caption.
    query = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
    # num_workers=0: the dataset is one caption, so a worker process per
    # query would only add start-up latency without changing the result.
    loader = DataLoader(query, batch_size=batch_size, num_workers=0,
                        pin_memory=True, collate_fn=collate_fn_cap_padded)

    caps_enc = []
    for caps, length in loader:
        input_caps = caps.to(device)
        with torch.no_grad():
            # Text-only forward pass; the image branch is given None.
            _, caps_emb = join_emb(None, input_caps, length)
        caps_enc.append(caps_emb)
    caps_stack = cp.vstack(caps_enc)

    # Map stored image paths to public COCO URLs. The URL is built with an
    # explicit "/" because os.path.join would insert "\" on Windows.
    coco_base = "http://images.cocodataset.org/train2017"
    imgs_url = [f"{coco_base}/{img_path.strip().split('_')[-1]}"
                for img_path in imgs_path]
    recall_imgs = recallTopK(caps_stack, imgs_emb, imgs_url, ks=100)

    # Walk the ranked candidates, skipping any that fail to download, until
    # topK images have been collected.
    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res
if __name__ == "__main__":
    # Load the trained joint embedding model. The map_location lambda keeps
    # the deserialized (CPU) storages as-is; the model is moved to `device`
    # afterwards.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    # Inference only: freeze all parameters and switch to eval mode.
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed image embeddings and their image paths.
    # NOTE(review): load_obj is a project helper, presumably pickle-based;
    # only point it at trusted files.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    imgs_emb = cp.asarray(imgs_emb)
    print("prepare done!")

    def _ordinal(n):
        """Return the English ordinal of *n* (1 -> "1st", 4 -> "4th", 11 -> "11th")."""
        if 10 <= n % 100 <= 20:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suffix}"

    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([T2I]),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="Introduce the search text...",
            ),
        ],
        theme="grass",
        # One image slot per result returned by search(); generating them
        # from topK keeps the two in sync and fixes the "4rd"/"5rd" typos.
        outputs=[
            gr.outputs.Image(type="auto", label=f"{_ordinal(rank)} Best match")
            for rank in range(1, topK + 1)
        ],
        # Title: "HUST graduation project - image/text retrieval system".
        title="HUST毕业设计-图文检索系统",
        # Description: "Please enter an image or text; related images will be
        # shown". NOTE(review): only text input is currently wired up.
        description="请输入图片或文本,将为您展示相关的图片:",
    )
    iface.launch(share=True)
|