Spaces:
Running
Running
from io import BytesIO | |
import string | |
import gradio as gr | |
import requests | |
from utils import Endpoint, get_token | |
def encode_image(image):
    """Serialize a PIL image into an in-memory JPEG file.

    Args:
        image: A PIL ``Image`` (the ``gr.Image(type="pil")`` value in this app).

    Returns:
        A ``BytesIO`` holding the JPEG bytes, rewound to offset 0 so it can be
        passed directly as a multipart file upload.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing uploaded).
    """
    if image is None:
        raise ValueError("No image provided.")
    # JPEG has no alpha channel or palette: Pillow refuses to write modes such
    # as RGBA, LA or P as JPEG, so normalise anything that is not already
    # RGB or greyscale.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)
    return buffered
def query_chat_api(
    image,
    prompt,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=60,
):
    """Send an image and a text prompt to the BLIP-2 generation endpoint.

    Args:
        image: PIL image to condition on; JPEG-encoded via ``encode_image``.
        prompt: Full text prompt (``inference_chat`` joins the conversation
            history into it).
        decoding_method: UI label; exactly ``"Nucleus sampling"`` turns on
            ``use_nucleus_sampling``, any other value leaves it False.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds ``requests`` waits for the server before raising
            ``requests.exceptions.Timeout``. New, optional, and finite so a
            stuck backend cannot hang the worker forever.

    Returns:
        The decoded JSON body on HTTP 200 (presumably a list of answer
        strings — confirm against the server), otherwise the string
        ``"Error: " + response.text``.
    """
    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }
    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}
    # Without a timeout requests may block indefinitely; with the app's queue
    # running one job at a time, that would stall every other user as well.
    response = requests.post(
        endpoint.url, data=data, files=files, headers=headers, timeout=timeout
    )
    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
def query_caption_api(
    image,
    decoding_method,
    temperature,
    len_penalty,
    repetition_penalty,
    timeout=60,
):
    """Ask the BLIP-2 caption endpoint to describe ``image``.

    The caption URL is derived from ``endpoint.url`` by replacing
    ``"/generate"`` with ``"/caption"``.

    Args:
        image: PIL image to caption; JPEG-encoded via ``encode_image``.
        decoding_method: UI label; exactly ``"Nucleus sampling"`` turns on
            ``use_nucleus_sampling``, any other value leaves it False.
        temperature: Forwarded as ``temperature``.
        len_penalty: Forwarded as ``length_penalty``.
        repetition_penalty: Forwarded as ``repetition_penalty``.
        timeout: Seconds ``requests`` waits for the server before raising
            ``requests.exceptions.Timeout``. New, optional, and finite so a
            stuck backend cannot hang the worker forever.

    Returns:
        The decoded JSON body on HTTP 200 (presumably a list of captions —
        confirm against the server), otherwise the string
        ``"Error: " + response.text``.
    """
    url = endpoint.url.replace("/generate", "/caption")
    headers = {
        "User-Agent": "BLIP-2 HuggingFace Space",
        "Auth-Token": get_token(),
    }
    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }
    files = {"image": encode_image(image)}
    # Without a timeout requests may block indefinitely; with the app's queue
    # running one job at a time, that would stall every other user as well.
    response = requests.post(
        url, data=data, files=files, headers=headers, timeout=timeout
    )
    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
def postprocess_output(output):
    """Make sure the first answer in ``output`` ends with punctuation.

    Args:
        output: List of answer strings decoded from the chat endpoint's JSON.
            Only ``output[0]`` is examined.

    Returns:
        The same list object. ``"."`` is appended to ``output[0]`` when that
        string is non-empty and its last character is not in
        ``string.punctuation``. An empty list or an empty first answer is
        returned unchanged instead of raising ``IndexError``.
    """
    if output and output[0] and output[0][-1] not in string.punctuation:
        output[0] += "."
    return output
def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    """Answer one chat turn about ``image`` (wired to the "Submit" button).

    All previous messages plus the new ``text_input`` are joined with single
    spaces into one prompt and sent to the chat endpoint.

    Args:
        image: PIL image currently loaded in the UI, or ``None``.
        text_input: The user's new message.
        decoding_method, temperature, length_penalty, repetition_penalty:
            Decoding controls forwarded unchanged to ``query_chat_api``.
        history: Alternating user/assistant messages from earlier turns (the
            ``state`` component). It is not modified; a new list is returned.

    Returns:
        ``{chatbot: [(user, assistant), ...], state: new_history}`` for Gradio.

    Raises:
        gr.Error: If no image is loaded or the endpoint reports an error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Work on a copy: mutating the session list in place would leave an
    # unanswered prompt behind whenever the request fails, shifting every
    # later (question, answer) pair. A None default also avoids the shared
    # mutable-default list.
    history = [] if history is None else list(history)
    history.append(text_input)
    prompt = " ".join(history)
    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    # query_chat_api reports non-200 responses as an "Error: ..." string.
    if isinstance(output, str):
        raise gr.Error(output)
    if not output:
        raise gr.Error("Error: the chat endpoint returned no answer.")
    output = postprocess_output(output)
    # Record exactly one answer per question so history stays strictly
    # alternating. Assumes the endpoint returns a list whose first item is
    # the answer; postprocess_output also only normalises output[0].
    history.append(output[0])
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return {chatbot: chat, state: history}
def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
    """Caption ``image`` (wired to the "Caption it!" button).

    Args:
        image: PIL image currently loaded in the UI, or ``None``.
        decoding_method, temperature, length_penalty, repetition_penalty:
            Decoding controls forwarded unchanged to ``query_caption_api``.

    Returns:
        The first caption returned by the caption endpoint.

    Raises:
        gr.Error: If no image is loaded, the endpoint reports an error, or it
            returns no caption.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )
    # query_caption_api signals failure with an "Error: ..." string; indexing
    # it would show only the first character ("E") in the caption box.
    if isinstance(output, str):
        raise gr.Error(output)
    if not output:
        raise gr.Error("Error: the caption endpoint returned no caption.")
    return output[0]
# Static page text rendered through gr.Markdown (inline HTML is used, so the
# tags below are kept balanced).
title = """<h1 align="center">BLIP-2</h1>"""
description = """<p>Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

# Backend handle from the project's utils module; query_chat_api and
# query_caption_api read its ``url`` attribute at request time.
endpoint = Endpoint()

# [image file, chat prompt] pairs shown under the UI; clicking one only fills
# the image and chat inputs.
examples = [
    ["house.png", "How could someone get out of the house?"],
    ["flower.jpg", "Question: What is this flower and where is its origin? Answer:"],
    ["forbidden_city.webp", "In what dynasties was this place built?"],
    # Disabled example, kept for reference:
    # ["sunset.png", "Write a romantic message that goes along this photo."],
]
with gr.Blocks() as iface:
    # Session chat history: the alternating user/assistant list that
    # inference_chat reads and returns.
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        # Left column: image, shared decoding controls, single-shot captioning.
        with gr.Column():
            image_input = gr.Image(type="pil")

            with gr.Row():
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.8,
                    step=0.1,
                    interactive=True,
                    label="Temperature (used with nucleus sampling)",
                )
                len_penalty = gr.Slider(
                    minimum=-1.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.2,
                    interactive=True,
                    label="Length Penalty (set to larger for longer sequence, used with beam search)",
                )
                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=1.5,
                    step=0.5,
                    interactive=True,
                    label="Repeat Penalty (larger value prevents repetition)",
                )

            with gr.Row():
                caption_output = gr.Textbox(lines=2, label="Caption Output (from OPT)")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                    ],
                    [caption_output],
                )

        # Right column: multi-turn chat about the current image.
        with gr.Column():
            chat_input = gr.Textbox(lines=2, label="Chat Input")

            with gr.Row():
                chatbot = gr.Chatbot(label="Chat Output (from FlanT5)")
                # Any change of image resets the chat, the pending input, the
                # caption and the stored history.
                image_input.change(
                    lambda: (None, "", "", []),
                    [],
                    [chatbot, chat_input, caption_output, state],
                )

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [chat_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

    # Examples only populate the inputs; no fn/outputs are given, so nothing
    # runs on click. Not assigned to a name, so the module-level ``examples``
    # data list is not rebound to this component.
    gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input],
    )

# Serve one job at a time with at most 10 waiting; api_open=False keeps
# clients from bypassing the queue through the REST routes.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
# NOTE(review): enable_queue appears redundant once .queue() is called on
# newer Gradio 3.x releases; kept for compatibility with the pinned version.
iface.launch(enable_queue=True)