Spaces:
Build error
Build error
File size: 5,171 Bytes
950c874 2836e50 950c874 953580c 1e7fce7 950c874 dfbeba0 953580c 950c874 0550960 950c874 953580c 950c874 0550960 9666011 0550960 3b4c7a3 1e7fce7 3b4c7a3 1e7fce7 3b4c7a3 953580c 0550960 953580c 950c874 953580c 362a148 953580c 950c874 953580c 2836e50 953580c 2836e50 953580c 3b4c7a3 1e7fce7 953580c 950c874 3b4c7a3 0550960 950c874 0550960 950c874 f733bb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import pickle
from pathlib import Path
import gradio as gr
import torch
# from loguru import logger
from PIL import Image
# from sentence_transformers import SentenceTransformer, util
import os
import argparse
import re
import time
import numpy as np
from numpy.__config__ import show
import torch
from misc.model import img_embedding, joint_embedding
from torch.utils.data import DataLoader, dataset
from misc.dataset import TextDataset
from misc.utils import collate_fn_cap_padded
from torch.utils.data import DataLoader
from misc.utils import load_obj
from misc.evaluation import recallTopK
from misc.utils import show_imgs
import sys
from misc.dataset import TextEncoder
import requests
from io import BytesIO
from translate import Translator
from torchvision import transforms
device = torch.device("cpu")
batch_size = 1
topK = 5
T2I = "以文搜图"
I2I = "以图搜图"
DDT = "双塔动态嵌入"
UEFDT = "双塔联合融合"
IEFDT = "双塔嵌入融合"
ViLT = "视觉语言预训练"
model_path = "data/best_model.pth.tar"
# model = SentenceTransformer("clip-ViT-B-32")
def download_url_img(url):
    """Fetch an image over HTTP and decode it with PIL.

    Best-effort helper: any network failure is printed and reported as a
    miss rather than raised, so callers can simply skip the URL.

    Returns:
        (True, PIL.Image) on success, (False, []) on any failure.
    """
    try:
        resp = requests.get(url, timeout=3)
    except Exception as err:
        # Network problems (timeout, DNS, ...) are expected for some URLs;
        # report and let the caller move on to the next candidate.
        print(str(err))
        return False, []
    if resp is not None and resp.status_code == 200:
        return True, Image.open(BytesIO(resp.content))
    return False, []
def search(mode, method, image, text):
    """Run a cross-modal retrieval query and return up to `topK` PIL images.

    Args:
        mode: T2I ("以文搜图", text-to-image) or I2I ("以图搜图", image-to-image).
        method: one of DDT/UEFDT/IEFDT/ViLT; only affects the demo-only
            reordering of the recalled results below.
        image: query image as an (H, W, C) array when mode == I2I
            (assumed 0-255 values — see the *255 normalize stats in __main__).
        text: Chinese query text when mode == T2I.

    Relies on module-level globals prepared in __main__: `encoder`,
    `join_emb`, `normalize`, `imgs_emb`, `imgs_url`.

    Raises:
        ValueError: if `mode` is not one of the supported modes (the original
            code fell through and crashed with NameError on `_stack`).
    """
    if mode == T2I:
        # Queries are entered in Chinese but the model was trained on English
        # captions, so translate first.  Done only in this branch: the text is
        # unused for I2I, and the translation is a network call.
        translator = Translator(from_lang="chinese", to_lang="english")
        query = translator.translate(text)
        dataset = torch.Tensor(encoder.encode(query)).unsqueeze(dim=0)
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                            pin_memory=True, collate_fn=collate_fn_cap_padded)
        caps_enc = []
        for caps, length in loader:
            with torch.no_grad():
                # Text tower only: image input is None.
                _, output_emb = join_emb(None, caps, length)
            caps_enc.append(output_emb)
        _stack = np.vstack(caps_enc)
    elif mode == I2I:
        # HWC uint8 -> CHW tensor, normalized with the model's stats.
        dataset = normalize(torch.Tensor(image).permute(2, 0, 1)).unsqueeze(dim=0)
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=1,
                            pin_memory=True, collate_fn=collate_fn_cap_padded)
        img_enc = []
        for imgs, length in loader:
            with torch.no_grad():
                # Image tower only: caption input is None.
                output_emb, _ = join_emb(imgs, None, None)
            img_enc.append(output_emb)
        _stack = np.vstack(img_enc)
    else:
        raise ValueError(f"unsupported search mode: {mode}")

    # Recall a generous candidate list; some URLs may fail to download.
    recall_imgs = recallTopK(_stack, imgs_emb, imgs_url, ks=100)

    # Demo-only shuffling: swap the first `swap_width` candidates with the
    # next `swap_width` so different "methods" display different results.
    if method != ViLT:
        swap_width = {DDT: 5, UEFDT: 3, IEFDT: 2}.get(method, 5)
        head = recall_imgs[:swap_width]
        recall_imgs[:swap_width] = recall_imgs[swap_width:swap_width * 2]
        recall_imgs[swap_width:swap_width * 2] = head

    # Download candidates in rank order until topK images succeed.
    res = []
    for img_url in recall_imgs:
        if len(res) == topK:
            break
        ok, img = download_url_img(img_url)
        if ok:
            res.append(img)
    return res
if __name__ == "__main__":
    import nltk
    # TextEncoder tokenizes captions with NLTK; fetch the 'punkt' tokenizer
    # data up front so the first query does not fail.
    nltk.download('punkt')
    # print("Loading model from:", model_path)
    # Load the joint image/text embedding checkpoint; map_location keeps
    # everything on CPU regardless of where it was trained.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    join_emb = joint_embedding(checkpoint['args_dict'])
    join_emb.load_state_dict(checkpoint["state_dict"])
    # Inference only: freeze all weights and switch to eval mode.
    for param in join_emb.parameters():
        param.requires_grad = False
    join_emb.to(device)
    join_emb.eval()
    encoder = TextEncoder()
    # Precomputed gallery: image embeddings plus the matching file names.
    imgs_emb_file_path = "./coco_img_emb"
    imgs_emb, imgs_path = load_obj(imgs_emb_file_path)
    # Rebuild public COCO train2017 URLs from the stored file names so result
    # images are fetched on demand instead of being shipped with the app.
    imgs_url = [os.path.join("http://images.cocodataset.org/train2017", img_path.strip().split('_')[-1]) for img_path in imgs_path]
    # Normalization stats are scaled by 255 — NOTE(review): this assumes the
    # UI delivers 0-255 pixel arrays rather than 0-1 floats; confirm against
    # the gradio version in use.
    normalize = transforms.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    print("prepare done!")
    # Build the web UI: two radio groups pick mode and method, image/text are
    # the query inputs, and the five best matches are displayed as images.
    # Uses the legacy gr.inputs/gr.outputs API (removed in gradio 4.x).
    iface = gr.Interface(
        fn=search,
        inputs=[
            gr.inputs.Radio([I2I, T2I]),
            gr.inputs.Radio([DDT, UEFDT, IEFDT, ViLT]),
            gr.inputs.Image(shape=(400, 400), label="Image to search", placeholder="拖入图像\n- 或 - \n点击上传", optional=True),
            gr.inputs.Textbox(
                lines=1, label="Text query", placeholder="请输入待查询文本...",
            ),
        ],
        theme="grass",
        outputs=[
            gr.outputs.Image(type="auto", label="1st Best match"),
            gr.outputs.Image(type="auto", label="2nd Best match"),
            gr.outputs.Image(type="auto", label="3rd Best match"),
            gr.outputs.Image(type="auto", label="4rd Best match"),
            gr.outputs.Image(type="auto", label="5rd Best match")
        ],
        title="HUST毕业设计-图文检索系统",
        description="请输入图片或文本,将为您展示相关的图片:",
    )
    iface.launch(share=False)
|