|
from io import BytesIO |
|
|
|
import string |
|
import gradio as gr |
|
import requests |
|
from utils import Endpoint |
|
|
|
|
|
def encode_image(image):
    """Encode *image* as an in-memory JPEG ready to be uploaded.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert`` and ``save``.

    Returns:
        A ``BytesIO`` holding the JPEG bytes, rewound to position 0 so it can be
        passed straight to ``requests`` as a file object.
    """
    # JPEG cannot store alpha, palettes or wide integer/float pixels; PIL raises
    # OSError for e.g. "RGBA", "LA", "P" or "I". Flatten those to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if getattr(image, "mode", "RGB") not in jpeg_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Rewind so the consumer reads from the beginning of the encoded data.
    buffered.seek(0)
    return buffered
|
|
|
|
|
def query_chat_api(
    image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
):
    """Send an image plus a text prompt to the remote BLIP-2 generation endpoint.

    Args:
        image: PIL image to condition on; encoded via ``encode_image``.
        prompt: Full text prompt (``inference_chat`` passes the joined history).
        decoding_method: UI choice; only ``"Nucleus sampling"`` turns on
            ``use_nucleus_sampling`` for the server.
        temperature: Forwarded to the server as ``temperature``.
        len_penalty: Forwarded to the server as ``length_penalty``.
        repetition_penalty: Forwarded to the server as ``repetition_penalty``.

    Returns:
        The decoded JSON body on HTTP 200 (presumably a list of generated
        strings, since callers index ``[0]`` -- confirm against the server).
        On any other status, a one-element list holding an ``"Error: ..."``
        message, so callers that take ``[0]`` get the whole message instead of
        its first character.
    """
    url = endpoint.url

    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    # Same container shape as the success path so downstream indexing works.
    return ["Error: " + response.text]
|
|
|
|
|
def query_caption_api(
    image, decoding_method, temperature, len_penalty, repetition_penalty
):
    """Ask the remote BLIP-2 service to caption *image* (no text prompt).

    The caption route is derived from the generation URL by replacing
    ``"/generate"`` with ``"/caption"``.

    Args:
        image: PIL image to caption; encoded via ``encode_image``.
        decoding_method: UI choice; only ``"Nucleus sampling"`` turns on
            ``use_nucleus_sampling`` for the server.
        temperature: Forwarded to the server as ``temperature``.
        len_penalty: Forwarded to the server as ``length_penalty``.
        repetition_penalty: Forwarded to the server as ``repetition_penalty``.

    Returns:
        The decoded JSON body on HTTP 200 (presumably a list of captions, since
        ``inference_caption`` takes ``[0]`` -- confirm against the server).
        On any other status, a one-element list holding an ``"Error: ..."``
        message, so callers that take ``[0]`` get the whole message.
    """
    url = endpoint.url.replace("/generate", "/caption")

    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    # Same container shape as the success path so downstream indexing works.
    return ["Error: " + response.text]
|
|
|
|
|
def postprocess_output(output):
    """Make sure the first generated string ends with punctuation.

    Args:
        output: List of generated strings (only ``output[0]`` is inspected and
            possibly modified in place), or a bare string such as an error
            message, which is wrapped in a one-element list first.

    Returns:
        The list, with ``"."`` appended to its first element when that element
        is non-empty and does not already end with an ASCII punctuation mark.
    """
    if isinstance(output, str):
        # A bare string cannot take item assignment; treat it as one reply.
        output = [output]

    # Nothing to terminate for an empty list or empty first string (and
    # ``output[0][-1]`` would raise IndexError there).
    if output and output[0] and output[0][-1] not in string.punctuation:
        output[0] += "."

    return output
|
|
|
|
|
def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Run one chat turn against the BLIP-2 endpoint and rebuild the transcript.

    Args:
        image: PIL image the conversation is about.
        text_input: The user's new message.
        decoding_method: Decoding choice forwarded to ``query_chat_api``.
        temperature: Forwarded to ``query_chat_api``.
        length_penalty: Forwarded to ``query_chat_api``.
        repetition_penalty: Forwarded to ``query_chat_api``.
        history: Flat list of previous turns, alternating user / model text.
            ``None`` (the default) starts a new conversation.

    Returns:
        A dict mapping the module-level ``chatbot`` component to the list of
        ``(user, model)`` pairs and the ``state`` component to the new history.
    """
    # Work on a copy: this avoids a shared mutable default, and if the API call
    # raises, the session state is not left holding an unanswered user turn
    # that would shift every later (user, model) pairing by one.
    history = list(history) if history is not None else []
    history.append(text_input)

    # The model sees the whole conversation so far as one space-joined prompt.
    prompt = " ".join(history)
    print(prompt)

    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    output = postprocess_output(output)
    # Guard against a bare string, which ``+=`` would splice in char by char.
    history.extend(output if isinstance(output, list) else [output])

    # Pair even (user) and odd (model) turns; a trailing unpaired turn is dropped.
    chat = list(zip(history[0::2], history[1::2]))

    return {chatbot: chat, state: history}
|
|
|
|
|
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Caption *image* via the remote service and return the caption text.

    Args:
        image: PIL image to caption.
        decoding_method: Decoding choice forwarded to ``query_caption_api``.
        temperature: Forwarded to ``query_caption_api``.
        length_penalty: Forwarded to ``query_caption_api``.
        repetition_penalty: Forwarded to ``query_caption_api``.

    Returns:
        The first caption returned by the service, the error message if the
        service reported one as a bare string, or ``""`` if it returned nothing.
    """
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )

    # Indexing a bare string would yield only its first character.
    if isinstance(output, str):
        return output
    if not output:
        return ""
    return output[0]
|
|
|
|
|
# Markdown/HTML snippets rendered at the top of the demo page.
title = """<h1 align="center">BLIP-2</h1>"""

# Opening <p> added: the original closed a paragraph it never opened.
description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""

# arXiv 2301.12597 is the BLIP-2 paper (2201.12086 is the original BLIP paper).
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"
|
|
|
# Handle on the remote inference service; the query helpers read its ``url``.
endpoint = Endpoint()

# Example rows for gr.Examples; each row is [image path, chat prompt].
examples = [
    [
        "house.png",
        "How could someone get out of the house?",
    ],
]
|
|
|
with gr.Blocks() as iface:
    # Per-session chat history: a flat list alternating user / model turns.
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        # Left column: image upload, decoding controls and one-shot captioning.
        with gr.Column():
            image_input = gr.Image(type="pil")

            with gr.Row():
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )

            temperature = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.8,
                interactive=True,
                # The minimum is 0.5, so the label must not advertise setting
                # the temperature to 0.
                label="Temperature (used with nucleus sampling)",
            )

            len_penalty = gr.Slider(
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.5,
                interactive=True,
                label="Length Penalty (larger value encourages longer sequence under beam search)",
            )

            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=1.5,
                step=0.5,
                interactive=True,
                label="Repeat Penalty (larger value prevents repetition)",
            )

            with gr.Row():
                caption_output = gr.Textbox(lines=2, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                    ],
                    [caption_output],
                )

        # Right column: multi-turn chat about the uploaded image.
        with gr.Column():
            chat_input = gr.Textbox(lines=2, label="Chat Input")

            with gr.Row():
                chatbot = gr.Chatbot()
                # A new image invalidates the conversation and caption.
                image_input.change(
                    lambda: (None, "", "", []),
                    [],
                    [chatbot, chat_input, caption_output, state],
                )

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [chat_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input],
    )

# One request at a time against the backend, at most 10 waiting.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# .queue() above already enables queuing; launch(enable_queue=...) is the
# deprecated Gradio 3.x switch for the same thing.
iface.launch()
|
|