Spaces:
Build error
Build error
File size: 5,209 Bytes
950c874 2836e50 950c874 262573b 950c874 953580c 1e7fce7 5be980c 097b01d dfbeba0 953580c 950c874 0550960 20307cf 0550960 950c874 953580c 950c874 0550960 9666011 5be980c 3b4c7a3 1e7fce7 3b4c7a3 1e7fce7 3b4c7a3 262573b 5be980c 953580c 950c874 953580c 362a148 953580c 950c874 953580c 2836e50 953580c 2836e50 953580c 3b4c7a3 1e7fce7 5be980c 1e7fce7 953580c 950c874 3b4c7a3 20307cf 3405edf 950c874 5be980c 950c874 5be980c 20307cf 5be980c 950c874 f733bb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import pickle
from pathlib import Path
import gradio as gr
import torch
# from loguru import logger
from PIL import Image
# from sentence_transformers import SentenceTransformer, util
import os
import argparse
import re
import time
import numpy as np
from numpy.__config__ import show
import torch
from misc.model import img_embedding, joint_embedding
from torch.utils.data import DataLoader, dataset
from misc.dataset import TextDataset
from misc.utils import collate_fn_cap_padded
from torch.utils.data import DataLoader
from misc.utils import load_obj
from misc.evaluation import recallTopK
from scripts.postprocess import postprocess
from misc.utils import show_imgs
import sys
from misc.dataset import TextEncoder
import requests
from io import BytesIO
from translate import Translator
from torchvision import transforms
import random
# ---- Module-level configuration --------------------------------------------
# Device the joint-embedding model is moved to before serving queries.
device = torch.device("cpu")
# Batch size handed to the DataLoader that wraps a single query in search().
batch_size = 1
# Maximum number of images search() returns (one per output slot in the UI).
topK = 5
# Search-mode labels shown in the UI; search() compares `mode` against them.
T2I = "以文搜图"  # "search images by text"
I2I = "以图搜图"  # "search images by image"
# Retrieval-method labels shown in the UI. search() receives the choice as
# `method` but does not currently branch on it.
DPDT = "双塔动态池化"  # "dual-tower dynamic pooling"
UEFDT = "双塔联合融合"  # "dual-tower joint fusion"
IEFDT = "双塔嵌入融合"  # "dual-tower embedding fusion"
ViLT = "视觉语言预训练"  # "vision-language pre-training"
# Checkpoint loaded at start-up; it is read for the keys "args_dict" and
# "state_dict".
model_path = "data/best_model.pth.tar"
# model = SentenceTransformer("clip-ViT-B-32")
def download_url_img(url):
    """Download ``url`` and open the response body as a PIL image.

    This is a best-effort helper: every failure is reported through the
    return value instead of an exception, so a caller iterating over many
    URLs can simply skip the bad ones.

    Args:
        url: HTTP(S) address of the image to fetch.

    Returns:
        A ``(ok, image)`` tuple. ``ok`` is ``True`` and ``image`` is the
        opened ``PIL.Image`` when the request returned status 200 and the
        body could be identified as an image. Otherwise the result is
        ``(False, [])``.
    """
    try:
        response = requests.get(url, timeout=3)
    except requests.RequestException as e:
        # Connection problems, timeouts, malformed URLs, ...
        print(str(e))
        return False, []

    if response.status_code != 200:
        return False, []

    try:
        image = Image.open(BytesIO(response.content))
    except OSError as e:
        # A 200 response is not guaranteed to carry a decodable image (e.g.
        # an HTML error page). PIL signals that with an OSError subclass;
        # report it as a failed download rather than aborting the caller.
        print(str(e))
        return False, []
    return True, image
def search(mode, method, image, text):
    """Return the best-matching gallery images for a text or image query.

    The query is embedded with the joint-embedding model, candidates are
    recalled against the precomputed gallery embeddings, and the candidate
    URLs are downloaded until ``topK`` images have been fetched.

    Args:
        mode: ``T2I`` to search with ``text`` or ``I2I`` to search with
            ``image``.
        method: Retrieval method selected in the UI. Accepted for interface
            compatibility; it is not used by the current implementation.
        image: Query image for ``I2I`` mode. It is permuted with
            ``(2, 0, 1)``, so a height x width x channel array is assumed.
        text: Query sentence for ``T2I`` mode.

    Returns:
        A list with at most ``topK`` downloaded images in recall order.
        Candidates whose download fails are skipped.

    Raises:
        ValueError: If ``mode`` is not one of ``T2I`` / ``I2I``, or if the
            input required by the selected mode is missing.
    """
    # NOTE: translating the query to English before encoding was disabled in
    # the original code, so `text` is encoded as entered.
    if mode == T2I:
        if not text or not text.strip():
            raise ValueError("a text query is required for text-to-image search")
        query = torch.Tensor(encoder.encode(text)).unsqueeze(dim=0)
    elif mode == I2I:
        if image is None:
            raise ValueError("an image is required for image-to-image search")
        query = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
    else:
        # Previously an unknown mode fell through and failed later with an
        # unbound-variable error; fail early with a readable message.
        raise ValueError("unsupported search mode: {!r}".format(mode))

    # NOTE(review): the same caption collate function is used for image
    # queries, exactly as in the original code - confirm this is intended.
    query_loader = DataLoader(
        query,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=True,
        collate_fn=collate_fn_cap_padded,
    )

    embeddings = []
    for batch, lengths in query_loader:
        with torch.no_grad():
            if mode == T2I:
                # Text branch: the second output is the caption embedding.
                _, output_emb = join_emb(None, batch, lengths)
            else:
                # Image branch: the first output is the image embedding.
                output_emb, _ = join_emb(batch, None, None)
        embeddings.append(output_emb)
    query_emb = np.vstack(embeddings)

    recall_imgs = recallTopK(query_emb, imgs_emb, imgs_url, ks=100)
    # The return value is ignored, as in the original; presumably this
    # adjusts `recall_imgs` in place - TODO confirm.
    postprocess(recall_imgs)

    # Recall more candidates than needed (ks=100) so that failed downloads
    # can be skipped while still filling up to `topK` results.
    results = []
    for img_url in recall_imgs:
        if len(results) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            results.append(img)
    return results
if __name__ == "__main__":
    # Script entry point: load the model and gallery data into module-level
    # names that search() reads, then build and launch the Gradio UI.
    import nltk

    # Fetch the "punkt" tokenizer data (presumably needed by the text
    # encoder - confirm).
    nltk.download('punkt')

    # Load the checkpoint onto CPU storage regardless of where it was saved.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    # Inference only: freeze all weights and switch to eval mode.
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()

    encoder = TextEncoder()

    # Precomputed gallery embeddings and the image paths they belong to.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    # Map each stored path to its public COCO URL using the part after the
    # last underscore. URLs always use "/", so build them with plain string
    # concatenation rather than os.path.join, which is platform-dependent.
    imgs_url = [
        "http://images.cocodataset.org/train2017" + "/" + img_path.strip().split('_')[-1]
        for img_path in imgs_path
    ]

    # Per-channel mean/std multiplied by 255, i.e. meant for pixel values in
    # the 0-255 range (assumed from the scaling - confirm against the image
    # input type).
    normalize = transforms.Normalize(
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )

    # Example inputs; the text-search examples use a blank placeholder image.
    cat_image = "./cat_example.jpg"
    dog_image = "./dog_example.jpg"
    w1_image = "./white.jpg"
    w2_image = "./white.jpg"
    print("prepare done!")

    # NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio
    # namespaces; verify they exist in the Gradio version this app is
    # deployed with.
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Radio([DPDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", optional=True),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="please input text query here...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4th Best match"),
            gr.outputs.Image(type="auto", label="5th Best match"),
        ],
        examples=[
            [I2I, DPDT, cat_image, ""],
            [I2I, ViLT, dog_image, ""],
            [T2I, UEFDT, w1_image, "a woman is walking on the road"],
            [T2I, IEFDT, w2_image, "a boy is eating apple"],
        ],
        title="HUST毕业设计-图文检索系统",
        description="请输入图片或文本,将为您展示相关的图片:",
    )
    iface.launch(share=False)
|