BLIP2 / app.py
root
update ui
21e0a1c
raw
history blame
7.02 kB
from io import BytesIO
import string
import gradio as gr
import requests
from utils import Endpoint
def encode_image(image):
    """Serialize a PIL image into an in-memory JPEG file.

    JPEG cannot store an alpha channel or a palette, so images in any mode
    other than RGB or greyscale ("L") are converted to RGB first; without
    this, saving e.g. an RGBA PNG upload as JPEG raises an error.

    Args:
        image: a PIL ``Image`` (as delivered by ``gr.Image(type="pil")``).

    Returns:
        A ``BytesIO`` rewound to offset 0, ready to be sent as a multipart
        file upload.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Rewind so the HTTP client reads the encoded bytes from the start.
    buffered.seek(0)
    return buffered
def query_chat_api(
    image,
    prompt,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=120,
):
    """Send an image and a text prompt to the BLIP-2 generation endpoint.

    Args:
        image: PIL image to condition on; JPEG-encoded via ``encode_image``.
        prompt: text prompt (the caller passes the joined chat history).
        decoding_method: UI choice; ``"Nucleus sampling"`` sends
            ``use_nucleus_sampling=True``, any other value sends False.
        temperature: forwarded unchanged as ``temperature``.
        len_penalty: forwarded as ``length_penalty``.
        repetition_penalty: forwarded unchanged as ``repetition_penalty``.
        timeout: seconds to wait for the server before giving up, so a hung
            backend cannot block the app indefinitely.

    Returns:
        The decoded JSON body on HTTP 200. Callers index it with ``[0]``,
        so it is presumably a list of generated strings (TODO confirm
        against the server). On any failure a one-element list
        ``["Error: <details>"]`` is returned so callers can index it the
        same way as a successful result.
    """
    url = endpoint.url
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}

    try:
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.RequestException as err:
        # Connection failures and timeouts are reported like HTTP errors
        # instead of crashing the Gradio event handler.
        return ["Error: " + str(err)]

    if response.status_code != 200:
        # Wrapped in a list: a bare string would make callers' ``[0]``
        # pick out only its first character.
        return ["Error: " + response.text]
    try:
        return response.json()
    except ValueError:
        # A 200 status with a body that is not valid JSON.
        return ["Error: " + response.text]
def query_caption_api(
    image,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=120,
):
    """Request a caption for an image from the BLIP-2 captioning endpoint.

    The captioning URL is derived from the generation URL by replacing
    ``/generate`` with ``/caption``.

    Args:
        image: PIL image to caption; JPEG-encoded via ``encode_image``.
        decoding_method: UI choice; ``"Nucleus sampling"`` sends
            ``use_nucleus_sampling=True``, any other value sends False.
        temperature: forwarded unchanged as ``temperature``.
        len_penalty: forwarded as ``length_penalty``.
        repetition_penalty: forwarded unchanged as ``repetition_penalty``.
        timeout: seconds to wait for the server before giving up, so a hung
            backend cannot block the app indefinitely.

    Returns:
        The decoded JSON body on HTTP 200. Callers index it with ``[0]``,
        so it is presumably a list of caption strings (TODO confirm against
        the server). On any failure a one-element list
        ``["Error: <details>"]`` is returned so callers can index it the
        same way as a successful result.
    """
    url = endpoint.url.replace("/generate", "/caption")
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}

    try:
        response = requests.post(
            url, data=data, files=files, headers=headers, timeout=timeout
        )
    except requests.RequestException as err:
        # Connection failures and timeouts are reported like HTTP errors
        # instead of crashing the Gradio event handler.
        return ["Error: " + str(err)]

    if response.status_code != 200:
        # Wrapped in a list: a bare string would make callers' ``[0]``
        # pick out only its first character.
        return ["Error: " + response.text]
    try:
        return response.json()
    except ValueError:
        # A 200 status with a body that is not valid JSON.
        return ["Error: " + response.text]
def postprocess_output(output):
    """Make sure the first generated answer ends with punctuation.

    Args:
        output: the value returned by the chat API helper, normally a list
            whose first element is the generated answer. A bare string
            (an error message) is treated as a single answer rather than
            as a sequence of characters.

    Returns:
        A new list; the argument is not modified. If its first element is a
        non-empty string whose last character is not in
        ``string.punctuation``, a full stop is appended to that element.
        Empty lists and empty strings are returned unchanged.
    """
    if isinstance(output, str):
        output = [output]
    result = list(output)
    if result and result[0] and result[0][-1] not in string.punctuation:
        result[0] += "."
    return result
def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Answer one chat turn about ``image`` and return the updated UI state.

    The whole conversation (previous prompts and answers plus the new
    prompt) is joined with spaces and sent as a single prompt.

    Args:
        image: PIL image the conversation is about (``None`` if nothing has
            been uploaded yet).
        text_input: the user's new message.
        decoding_method, temperature, length_penalty, repetition_penalty:
            decoding settings forwarded to ``query_chat_api``.
        history: flat list alternating user messages and model answers.
            It is copied, never mutated, so a failed request cannot leave an
            unanswered question in the stored session state.

    Returns:
        A dict mapping the ``chatbot`` component to a list of
        ``(message, answer)`` tuples and the ``state`` component to the new
        flat history list.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    history = [] if history is None else list(history)
    history.append(text_input)
    prompt = " ".join(history)

    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    history += postprocess_output(output)

    # Pair even-indexed entries (user turns) with odd-indexed ones (answers).
    chat = list(zip(history[0::2], history[1::2]))
    return {chatbot: chat, state: history}
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Caption ``image`` with the given decoding settings.

    Args:
        image: PIL image to caption (``None`` if nothing was uploaded).
        decoding_method, temperature, length_penalty, repetition_penalty:
            decoding settings forwarded to ``query_caption_api``.

    Returns:
        The first caption returned by the API, or the full error message if
        the API helper reported a failure as a plain string.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )
    # Indexing a bare error string would return only its first character.
    if isinstance(output, str):
        return output
    return output[0]
# HTML snippets rendered through gr.Markdown at the top of the page.
title = """<h1 align="center">BLIP-2</h1>"""
# Both paragraphs are wrapped in matching <p>...</p> tags.
description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""
# Link to the BLIP-2 paper (arXiv 2301.12597); 2201.12086 is the original
# BLIP paper.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"
# Client for the remote BLIP-2 inference service; the API helpers read its
# ``url`` attribute when sending each request.
endpoint = Endpoint()

# Rows of [image path, chat prompt] that pre-fill the image and chat inputs.
# A second row (image "sunset.png", prompt "Write a romantic message that
# goes along this photo.") is currently disabled.
examples = [
    [
        "house.png",
        "How could someone get out of the house?",
    ],
]
with gr.Blocks() as iface:
    # Per-session chat history: a flat list alternating user messages and
    # model answers, read and returned by inference_chat.
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        # Left column: image upload, decoding controls and captioning.
        with gr.Column():
            image_input = gr.Image(type="pil")

            with gr.Row():
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )

            # The minimum is 0.5, so the label must not tell users to set
            # the temperature to 0; that value cannot be selected.
            temperature = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.8,
                interactive=True,
                label="Temperature",
            )
            len_penalty = gr.Slider(
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.5,
                interactive=True,
                label="Length Penalty (larger value encourages longer sequence under beam search)",
            )
            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=1.5,
                step=0.5,
                interactive=True,
                label="Repeat Penalty (larger value prevents repetition)",
            )

            with gr.Row():
                caption_output = gr.Textbox(lines=2, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
            caption_button.click(
                inference_caption,
                [
                    image_input,
                    sampling,
                    temperature,
                    len_penalty,
                    rep_penalty,
                ],
                [caption_output],
            )

        # Right column: multi-turn chat about the uploaded image.
        with gr.Column():
            chat_input = gr.Textbox(lines=2, label="Chat Input")

            with gr.Row():
                chatbot = gr.Chatbot()

            # Changing the image resets the chat window, the chat input,
            # the caption and the stored history.
            image_input.change(
                lambda: (None, "", "", []),
                [],
                [chatbot, chat_input, caption_output, state],
            )

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [chat_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input],
    )
# Serve requests through Gradio's queue: one at a time (concurrency_count=1),
# with at most 10 waiting (max_size=10), and api_open=False.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# .queue() above already enables queuing; launch()'s ``enable_queue``
# argument is deprecated, so it is not passed here.
iface.launch()