vqa_blip_large / app.py
xxx1's picture
Create app.py
eb65419
raw
history blame
2.91 kB
import string
import gradio as gr
import requests
import torch
from transformers import BlipForQuestionAnswering, BlipProcessor
# Checkpoint shared by the processor and the VQA model.
_CHECKPOINT = "Salesforce/blip-vqa-large"

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the processor and the model once at import time, then move the model
# onto the selected device.
processor = BlipProcessor.from_pretrained(_CHECKPOINT)
model_vqa = BlipForQuestionAnswering.from_pretrained(_CHECKPOINT)
model_vqa = model_vqa.to(device)
def inference_chat(input_image, input_text):
    """Answer a question about an image with the BLIP VQA model.

    Args:
        input_image: Image to ask about; forwarded to the processor as
            ``images``. ``None`` means no image has been provided yet.
        input_text: The question; forwarded to the processor as ``text``.

    Returns:
        The decoded answer for the first generated sequence, with special
        tokens stripped, or an empty string when ``input_image`` is ``None``.
    """
    # Nothing to answer without an image; avoid handing None to the processor.
    if input_image is None:
        return ""

    inputs = processor(images=input_image, text=input_text, return_tensors="pt")
    # The model lives on `device`; the tensors must be moved there as well,
    # otherwise generation fails with a device mismatch on CUDA hosts.
    inputs = inputs.to(device)

    # Pass generation settings explicitly instead of mixing them into the
    # tensor batch.
    out = model_vqa.generate(**inputs, max_length=20, num_beams=5)
    return processor.batch_decode(out, skip_special_tokens=True)[0]
# Build the VQA demo page: image upload, question box, Clear/Submit buttons
# and an answer box. All components are created first and the event handlers
# are attached afterwards, so every handler refers to a real component
# (previously the answer box was still None when the handlers were wired).
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Per-session state; reset together with the text boxes below.
    state = gr.State([])

    #gr.Markdown(title)
    #gr.Markdown(description)
    #gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")

    with gr.Row():
        with gr.Column(scale=1):
            chat_input = gr.Textbox(lines=1, label="VQA Input")
            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
            caption_output = gr.Textbox(lines=1, label="VQA Output")

    # Pressing Enter in the question box and clicking Submit run the same
    # inference with the same inputs and output.
    vqa_inputs = [image_input, chat_input]
    vqa_outputs = [caption_output]
    chat_input.submit(inference_chat, vqa_inputs, vqa_outputs)
    submit_button.click(inference_chat, vqa_inputs, vqa_outputs)

    # Clear resets the question, the answer and the state; the number of
    # returned values matches the number of output components.
    clear_button.click(
        lambda: ("", "", []),
        [],
        [chat_input, caption_output, state],
        queue=False,
    )

    # A new image invalidates the previous answer and state.
    image_input.change(
        lambda: ("", []),
        [],
        [caption_output, state],
        queue=False,
    )
# examples = gr.Examples(
# examples=examples,
# inputs=[image_input, chat_input],
# )
# Queue requests (at most 10 waiting, queue API closed) and start the server.
# `concurrency_count=1` was the Gradio 3.x default and is rejected by newer
# releases, and `launch(enable_queue=True)` is deprecated and redundant once
# `queue()` has been called, so both are omitted.
iface.queue(api_open=False, max_size=10)
iface.launch()