|
from functools import partial |
|
import json |
|
from multiprocessing.pool import ThreadPool as Pool |
|
import gradio as gr |
|
from utils import * |
|
|
|
from clip_retrieval.clip_client import ClipClient |
|
|
|
|
|
def text2image_gr():
    """Build the Gradio Blocks UI for CLIP text-to-image search.

    The UI collects a text query, a result count, a model choice and a
    thumbnail flag, sends the query to a local clip-retrieval KNN service
    (``ClipClient``), and shows the returned cover images in a gallery.

    NOTE(review): ``clip_base``, ``yes``, ``no``, ``description`` and
    ``url2img`` are not defined here; presumably they come from
    ``from utils import *`` -- confirm there.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` demo.
    """

    def clip_api(query_text='', return_n=8, model_name=clip_base, thumbnail="是"):
        """Query the KNN service and build gallery items.

        Args:
            query_text: Text to search for.
            return_n: Number of neighbours to request. The slider may pass a
                float, so it is coerced with ``int``.
            model_name: Selected model label. NOTE(review): currently unused;
                the query always targets the ``ltr_cover_index`` index.
            thumbnail: Forwarded to ``url2img`` as its ``thumbnail`` keyword.
                Presumably "是" ("yes") requests a thumbnail; confirm in
                ``url2img``.

        Returns:
            A list of ``[image, caption]`` pairs for ``gr.Gallery``. Each
            caption is a JSON string with ``cover_url``, ``similarity``
            (rounded to 6 places) and ``docid``. Returns ``None`` when the
            service returns no hits.
        """
        # Single slash before the route; the original "//knn-service" was not
        # the canonical path of the service endpoint.
        client = ClipClient(url="http://127.0.0.1:1234/knn-service",
                            indice_name="ltr_cover_index",
                            aesthetic_weight=0,
                            num_images=int(return_n))

        result = client.query(text=query_text)
        if not result:
            print("no result found")
            return None

        print(f"get result sucessed, num: {len(result)}")

        cover_urls = [res['cover_url'] for res in result]
        # Serialize captions as real JSON (not a Python dict repr) so they can
        # be parsed downstream. ensure_ascii=False keeps non-ASCII text readable.
        cover_info = [
            json.dumps({"cover_url": res['cover_url'],
                        "similarity": round(res['similarity'], 6),
                        "docid": res['docids']},
                       ensure_ascii=False)
            for res in result
        ]

        # Convert URLs to images concurrently. url2img presumably fetches
        # over the network, so threads suffice. The context manager tears the
        # pool down even if url2img raises; map() has already finished on
        # the success path.
        fetch_image = partial(url2img, thumbnail=thumbnail)
        with Pool() as pool:
            ret_imgs = pool.map(fetch_image, cover_urls)

        return [[img, info] for img, info in zip(ret_imgs, cover_info)]

    # Example rows: [query text, result count, model, thumbnail flag].
    examples = [
        ["cat", 12, clip_base, "是"],
        ["dog", 12, clip_base, "是"],
        ["bag", 12, clip_base, "是"],
        ["a cat is sit on the table", 12, clip_base, "是"]
    ]

    # Page heading ("CLIP text-to-image search application").
    title = "<h1 align='center'>CLIP文到图搜索应用</h1>"

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            # Left column: query controls.
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    text = gr.Textbox(value="cat", label="请填写文本", elem_id=0, interactive=True)
                # minimum=1: asking the service for zero images is meaningless.
                num = gr.components.Slider(minimum=1, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
                model = gr.components.Radio(label="模型选择", choices=[clip_base],
                                            value=clip_base, elem_id=3)
                thumbnail = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
                                                value=yes, elem_id=4)
                btn = gr.Button("搜索", )
            # Right column: result gallery.
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
        # Order must match clip_api's positional parameters.
        inputs = [text, num, model, thumbnail]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return demo
|
|
|
if __name__ == "__main__":
    # Close any Gradio apps already running in this process before
    # launching a new one.
    gr.close_all()
    # Finish constructing the tabbed app first, then launch it. This avoids
    # calling launch() from inside a re-entered `with` context, which is not
    # the pattern Gradio documents.
    demo = gr.TabbedInterface(
        [text2image_gr()],
        ["文到图搜索"],  # tab title: "text-to-image search"
    )
    demo.launch(server_name='127.0.0.1', share=False)
|
|