|
from PIL import Image |
|
import gradio as gr |
|
import torch |
|
import requests |
|
from models.zhclip import ZhCLIPProcessor, ZhCLIPModel |
|
|
|
# Checkpoint identifier handed to both ``from_pretrained`` calls below
# (presumably a Hugging Face Hub repo id -- confirm against models.zhclip).
version = 'nlpcver/zh-clip-vit-roberta-large-patch14'

# Loaded once at import time and shared by every call to ``get_result``.
processor = ZhCLIPProcessor.from_pretrained(version)
model = ZhCLIPModel.from_pretrained(version)
|
|
|
|
|
|
|
def get_result(image, text):
    """Score how well ``text`` matches ``image`` with the loaded Zh-CLIP model.

    Args:
        image: The image to score. The UI passes the value of a
            ``gr.Image(type="pil")`` component, which is ``None`` when nothing
            has been uploaded.
        text: The single caption/question to compare against the image.

    Returns:
        The softmax (over the last axis) of the image-feature / text-feature
        dot products, as returned by the model for ``return_tensors="pt"``
        inputs.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    # Fail early with a readable message instead of an opaque error from
    # deep inside the processor/model when no image was provided.
    if image is None:
        raise ValueError("An image is required to compute the image-text score.")

    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)

    # Pure inference: skip autograd bookkeeping so no graph is built and the
    # returned tensor carries no grad_fn.
    with torch.no_grad():
        outputs = model(**inputs)
        image_features = outputs.image_features
        text_features = outputs.text_features

        # NOTE(review): only one text is passed in, so this softmax presumably
        # normalises over a single logit and yields 1.0 whatever the inputs
        # are -- confirm whether a raw similarity score (or several candidate
        # texts) was intended.
        text_probs = (image_features @ text_features.T).softmax(dim=-1)

    return text_probs
|
# Gradio front end: an image input and a one-line text input feed
# ``get_result``; its return value is shown in the "ITM" textbox.
#
# NOTE(review): the source formatting had lost its indentation, so the
# Row/Column nesting below is a reconstruction -- confirm the intended layout.
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Session state; it is only reset by the Clear button and is not read by
    # any handler in this block.
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="VQA Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Question Input")
                    with gr.Row():
                        # ``width`` is not a Button argument, so it was dropped.
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
                        # Two further buttons ("Submit_CAP", "Submit_GPT3")
                        # were sketched here but disabled; they are not built.
        with gr.Column():
            caption_output = gr.Textbox(lines=0, label="ITM")

    # Pressing Enter in the text box and clicking Submit run the same scoring.
    chat_input.submit(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )

    # The handler must return exactly one value per output component:
    # empty question, empty state list, empty result box.
    clear_button.click(
        lambda: ("", [], ""),
        [],
        [chat_input, state, caption_output],
        queue=False,
    )

    submit_button.click(
        get_result,
        [image_input, chat_input],
        [caption_output],
    )

# ``queue()`` already enables queuing, so the deprecated ``enable_queue``
# launch flag is not needed.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()
|
|