# Hugging Face Space: BLIP visual-question-answering demo.
# (The original scrape captured the Space status banner — "Spaces: / Runtime
# error" — here; it was page residue, not part of the program.)
# --- imports -----------------------------------------------------------------
# stdlib
import string  # noqa: F401 -- unused here; kept to avoid breaking unseen code

# third-party
import gradio as gr
import requests  # noqa: F401 -- unused here; kept to avoid breaking unseen code
import torch
from transformers import BlipForQuestionAnswering, BlipProcessor

# --- model setup -------------------------------------------------------------
# Prefer GPU when available; every tensor handed to the model must live on
# this same device (see inference_chat).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BLIP VQA: the processor turns an (image, question) pair into model inputs;
# the model generates the answer as text token ids.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-large").to(device)
def inference_chat(input_image, input_text):
    """Answer a visual question about *input_image*.

    Args:
        input_image: PIL image from the Gradio image widget.
        input_text: the question string typed by the user.

    Returns:
        The model's decoded answer as a plain string.
    """
    inputs = processor(images=input_image, text=input_text, return_tensors="pt")
    # BUG FIX: the processor returns CPU tensors while the model was moved to
    # `device`; on a CUDA machine generate() then raises a device-mismatch
    # RuntimeError. Move every input tensor onto the model's device first.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Pass the generation options explicitly instead of smuggling them through
    # the inputs dict; no_grad() skips autograd bookkeeping during inference.
    with torch.no_grad():
        out = model_vqa.generate(**inputs, max_length=20, num_beams=5)
    return processor.batch_decode(out, skip_special_tokens=True)[0]
# --- UI ----------------------------------------------------------------------
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Conversation state (currently only cleared, never appended to).
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")

            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Input")
                    # BUG FIX: create the output textbox BEFORE wiring events
                    # to it. Previously `caption_output` was still None when
                    # .submit()/.click() registered it as the output, so the
                    # model's answer was never displayed anywhere.
                    caption_output = gr.Textbox(lines=1, label="VQA Output")

                    # Pressing Enter in the question box runs the model.
                    chat_input.submit(
                        inference_chat,
                        [image_input, chat_input],
                        [caption_output],
                    )

                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        # BUG FIX: the callback must return exactly one value
                        # per output component (was 3 values for 2 outputs).
                        clear_button.click(
                            lambda: ("", []),
                            [],
                            [chat_input, state],
                            queue=False,
                        )

                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
                        submit_button.click(
                            inference_chat,
                            [image_input, chat_input],
                            [caption_output],
                        )

            # Uploading a new image resets the answer and the state.
            # BUG FIX: was returning 3 values for 2 output components.
            image_input.change(
                lambda: ("", []),
                [],
                [caption_output, state],
                queue=False,
            )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(enable_queue=True)