|
from functools import partial |
|
import json |
|
from multiprocessing.pool import ThreadPool as Pool |
|
import gradio as gr |
|
import PIL |
|
from PIL import Image |
|
from utils import * |
|
|
|
from clip_retrieval.clip_client import ClipClient |
|
|
|
def image2text_gr(): |
|
def clip_api(query_image=None, return_n=8, model_name=clip_base, thumbnail=yes): |
|
|
|
client = ClipClient(url="http://127.0.0.1:1234//knn-service", |
|
indice_name="ltr_cover_index", |
|
aesthetic_weight=0, |
|
num_images=int(return_n)) |
|
result = client.query(image=query_image) |
|
|
|
if not result or len(result) == 0: |
|
print("no result found") |
|
return None |
|
|
|
print(f"get result sucessed, num: {len(result)}") |
|
|
|
cover_urls = [res['cover_url'] for res in result] |
|
cover_info = [] |
|
for res in result: |
|
json_info = {"cover_url": res['cover_url'], |
|
"similarity": round(res['similarity'], 6), |
|
"docid": res['docids']} |
|
cover_info.append(str(json_info)) |
|
pool = Pool() |
|
new_url2image = partial(url2img, thumbnail=thumbnail) |
|
ret_imgs = pool.map(new_url2image, cover_urls) |
|
pool.close() |
|
pool.join() |
|
|
|
new_ret = [] |
|
for i in range(len(ret_imgs)): |
|
new_ret.append([ret_imgs[i], cover_info[i]]) |
|
return new_ret |
|
|
|
examples = [ |
|
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", 20, |
|
clip_base, "是"], |
|
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000080.jpg", 20, |
|
clip_base, "是"], |
|
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000009.jpg", |
|
20, clip_base, "是"], |
|
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000308.jpg", |
|
20, clip_base, "是"] |
|
] |
|
|
|
title = "<h1 align='center'>CLIP图到图搜索应用</h1>" |
|
|
|
with gr.Blocks() as demo: |
|
gr.Markdown(title) |
|
gr.Markdown(description) |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
with gr.Column(scale=2): |
|
img = gr.Textbox(value="https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", label="图片地址", elem_id=0, interactive=True) |
|
num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2) |
|
model = gr.components.Radio(label="模型选择", choices=[clip_base], |
|
value=clip_base, elem_id=3) |
|
tn = gr.components.Radio(label="是否返回缩略图", choices=[yes, no], |
|
value=yes, elem_id=4) |
|
btn = gr.Button("搜索", ) |
|
with gr.Column(scale=100): |
|
out = gr.Gallery(label="检索结果为:", columns=4, height="auto") |
|
inputs = [img, num, model, tn] |
|
btn.click(fn=clip_api, inputs=inputs, outputs=out) |
|
gr.Examples(examples, inputs=inputs) |
|
return demo |
|
|
|
|
|
if __name__ == "__main__":
    # gr.TabbedInterface builds its whole layout in its constructor, so it is
    # launched directly; re-opening it as a context manager and calling the
    # blocking launch() inside the `with` would defer the context's
    # finalisation until after the server shuts down.
    demo = gr.TabbedInterface(
        [image2text_gr()],
        ["图到图搜索"],
    )
    demo.launch(
        server_name='127.0.0.1',
        share=False,
    )
|
|