# BLIP2 / app.py
# Dongxu Li
# add headers
# 30474d6
# raw history blame
# No virus
# 3.32 kB
from PIL import Image
import requests
import json
import gradio as gr
from io import BytesIO
def encode_image(image):
    """Serialize an image into an in-memory JPEG stream.

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        A ``BytesIO`` containing the JPEG data, rewound to offset 0 so a
        consumer can read it from the beginning.
    """
    stream = BytesIO()
    image.save(stream, format="JPEG")
    # Writing leaves the cursor at the end of the data; rewind for readers.
    stream.seek(0)
    return stream
def query_api(image, prompt, decoding_method, timeout=120):
    """Send an image and a text prompt to the remote generation endpoint.

    Args:
        image: Image object; it is JPEG-encoded with ``encode_image`` and
            uploaded as the ``image`` file field.
        prompt: Text sent as the ``prompt`` form field.
        decoding_method: Decoding choice. Only the exact string
            ``"Nucleus sampling"`` turns nucleus sampling on; any other
            value turns it off.
        timeout: Seconds ``requests`` may wait for the server before giving
            up. Without a timeout a stalled server would block the calling
            worker indefinitely.

    Returns:
        The decoded JSON body when the server answers with HTTP 200,
        otherwise a string of the form ``"Error: <response text>"``.

    Raises:
        requests.RequestException: If the request cannot be completed
            (connection failure, timeout, ...).
    """
    # Remote inference server (a fixed public IP, not localhost).
    url = "http://34.132.142.70:5000/api/generate"
    headers = {
        'User-Agent': 'BLIP-2 HuggingFace Space'
    }
    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
    }
    files = {"image": encode_image(image)}

    response = requests.post(
        url, data=data, files=files, headers=headers, timeout=timeout
    )

    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
def prepend_question(text):
    """Return *text* trimmed and lower-cased, labelled with ``"question: "``."""
    normalized = text.strip().lower()
    return f"question: {normalized}"
def prepend_answer(text):
    """Return *text* trimmed and lower-cased, labelled with ``"answer: "``."""
    normalized = text.strip().lower()
    return f"answer: {normalized}"
def get_prompt_from_history(history):
    """Build the newline-separated prompt from a flat conversation history.

    Entries at even positions are formatted with ``prepend_question`` and
    entries at odd positions with ``prepend_answer``.

    Args:
        history: Sequence of strings alternating question, answer, ...

    Returns:
        The formatted entries joined with ``"\\n"`` (empty string for an
        empty history).
    """
    labelled = [
        prepend_question(turn) if position % 2 == 0 else prepend_answer(turn)
        for position, turn in enumerate(history)
    ]
    return "\n".join(labelled)
def postp_answer(text):
    """Strip a leading ``"answer: "`` or ``"a: "`` label from a reply.

    The slice length is derived from the matched prefix, so the whole label
    (including its trailing space) is removed for both spellings.

    Args:
        text: Reply text, possibly starting with one of the labels.

    Returns:
        The text without the label, or *text* unchanged when no label is
        present.
    """
    for prefix in ("answer: ", "a: "):
        if text.startswith(prefix):
            return text[len(prefix):]
    return text
def prep_question(text):
    """Normalize a user question before it is stored in the history.

    A leading ``"question: "`` or ``"q: "`` label is removed (the slice
    length comes from the matched prefix, so no stray space is left
    behind), and a ``"?"`` is appended when the text does not already end
    with one.

    Args:
        text: Raw question text typed by the user.

    Returns:
        The question without a label and guaranteed to end with ``"?"``.
    """
    for prefix in ("question: ", "q: "):
        if text.startswith(prefix):
            text = text[len(prefix):]
            break  # at most one label is stripped

    if not text.endswith("?"):
        text += "?"
    return text
def inference(image, text_input, decoding_method, history=None):
    """Run one chat turn: ask the model a question about the image.

    Args:
        image: Image forwarded to ``query_api``.
        text_input: The user's question; normalized with ``prep_question``.
        decoding_method: Decoding choice forwarded to ``query_api``.
        history: Flat list of earlier turns alternating question, answer.
            ``None`` (the default) or an empty list starts a new
            conversation. The list passed in is not modified.

    Returns:
        A ``(chat, history)`` tuple. ``chat`` is a list of
        ``(question, answer)`` tuples for display and ``history`` is the
        updated flat list. If ``query_api`` reports a failure (it returns an
        ``"Error: ..."`` string), the message is shown as the reply to the
        current question but the turn is not added to ``history``, so the
        error text never leaks into later prompts.
    """
    # Copy instead of using a mutable default / mutating the caller's list:
    # a shared default list would carry turns over between conversations.
    history = list(history) if history else []

    question = prep_question(text_input)
    prompt = get_prompt_from_history(history + [question])
    output = query_api(image, prompt, decoding_method)

    if isinstance(output, str):
        # Error path: indexing the string would keep only its first
        # character, so surface the full message and leave history as is.
        chat = list(zip(history[::2], history[1::2]))
        chat.append((question, output))
        return chat, history

    history += [question, postp_answer(output[0])]
    # Pair up consecutive (question, answer) entries for the chatbot widget.
    chat = list(zip(history[::2], history[1::2]))
    return chat, history
# Input widgets, listed in the same order as the parameters of `inference`;
# the trailing "state" entry feeds its `history` argument.
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Textbox(lines=2, label="Text input"),
    gr.inputs.Radio(
        choices=['Nucleus sampling', 'Beam search'],
        type="value",
        default="Nucleus sampling",
        label="Text Decoding Method",
    ),
    "state",
]
# `inference` returns (chat, history): the chat pairs are displayed and the
# history is stored back into the state.
outputs = ["chatbot", "state"]

title = "BLIP-2"

description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""

# Link target is the BLIP-2 paper (arXiv 2301.12597), matching the link text.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
iface.launch(enable_queue=True)