# zh-clip / app.py
import gradio as gr
import torch
from models.zhclip import ZhCLIPProcessor, ZhCLIPModel  # From https://www.github.com/yue-gang/ZH-CLIP

# Load the pretrained ZH-CLIP checkpoint and its matching processor.
version = 'nlpcver/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)
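# Note: the checkpoint is fetched from the Hugging Face Hub on first run; the
# `models.zhclip` package is not on PyPI and must come from the ZH-CLIP repo
# linked above (assumption: it sits next to this file or is on PYTHONPATH).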
def get_result(image, text):
    """Return an image-text matching score for one image and one Chinese caption."""
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    image_features = outputs.image_features
    text_features = outputs.text_features
    # L2-normalize so the dot product is a cosine similarity. The original
    # softmax over a single candidate text always returned 1.0, which made
    # the displayed score meaningless.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return (image_features @ text_features.T).item()
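# A minimal standalone sanity check, outside the Gradio UI (hypothetical file
# name; any local image works):
#
#     from PIL import Image
#     img = Image.open("example.jpg")
#     print(get_result(img, "一只猫"))  # "a cat"; higher score = better match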
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
                        # Disabled alternative pipelines, kept for reference:
                        # cap_submit_button = gr.Button(
                        #     value="Submit_CAP", interactive=True, variant="primary"
                        # )
                        # gpt3_submit_button = gr.Button(
                        #     value="Submit_GPT3", interactive=True, variant="primary"
                        # )
        with gr.Column():
            caption_output = gr.Textbox(lines=1, label="ITM")
    # Either pressing Enter in the textbox or clicking Submit scores the
    # current image/question pair.
    chat_input.submit(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )
    clear_button.click(
        # Reset the question box, the state, and the ITM output
        # (one return value per output component).
        lambda: ("", [], ""),
        [],
        [chat_input, state, caption_output],
        queue=False,
    )
    submit_button.click(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()  # queue() above already enables the queue; enable_queue= is deprecated
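# To run locally (assuming gradio 3.x, which matches queue(concurrency_count=...)):
#     python app.py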