# OtterHD-Demo / app.py
# Last commit 1d62154 by luodian: "update" (5.18 kB)
import os
import datetime
import json
import base64
from PIL import Image
import gradio as gr
import hashlib
import requests
from utils import build_logger
import io
# Root directory for this app's logs; it is passed to build_logger, and the
# path helpers below build conversation/image log locations under it.
LOGDIR = "log"
logger = build_logger("otter", LOGDIR)
def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Inverse of ``encode_image``: the text is base64-decoded and the resulting
    bytes are opened with ``PIL.Image.open``.

    Args:
        encoded_image: base64 text of an encoded image file (e.g. PNG bytes).

    Returns:
        The opened ``PIL.Image.Image``. ``Image.open`` is lazy, so pixel data is
        read from the in-memory buffer on first access.
    """
    decoded_bytes = base64.b64decode(encoded_image.encode("utf-8"))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image
def encode_image(image: Image.Image, format: str = "PNG") -> str:
    """Serialize *image* in the given file *format* and return it as base64 text.

    Args:
        image: the PIL image to encode.
        format: Pillow format name forwarded to ``Image.save`` (default ``"PNG"``).

    Returns:
        The base64-encoded file bytes, decoded to a UTF-8 ``str``.
    """
    sink = io.BytesIO()
    try:
        image.save(sink, format=format)
        raw_bytes = sink.getvalue()
    finally:
        sink.close()
    b64_bytes = base64.b64encode(raw_bytes)
    return b64_bytes.decode("utf-8")
def get_conv_log_filename():
    """Return today's conversation-log path, ``<LOGDIR>/YYYY-MM-DD-conv.json``.

    The date comes from the local (naive) current time.
    """
    today = datetime.datetime.now()
    stamp = f"{today.year}-{today:%m}-{today:%d}"
    return os.path.join(LOGDIR, stamp + "-conv.json")
def get_conv_image_dir():
    """Return ``<LOGDIR>/images``, creating the directory first if it is missing."""
    image_dir = os.path.join(LOGDIR, "images")
    os.makedirs(image_dir, exist_ok=True)  # no-op when it already exists
    return image_dir
def get_image_name(image, image_dir=None):
    """Return a content-addressed ``.png`` filename for *image*.

    The name is the MD5 hex digest of the image's PNG encoding followed by
    ``.png``. When *image_dir* is given, it is joined in front of the name.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        digest = hashlib.md5(png_buffer.getvalue()).hexdigest()
    file_name = f"{digest}.png"
    return file_name if image_dir is None else os.path.join(image_dir, file_name)
def resize_image(image, max_size):
    """Resize *image* so its longer side equals *max_size*, keeping the aspect ratio.

    Square images become ``max_size x max_size``. Images smaller than *max_size*
    are scaled up, not left unchanged. The shorter side is truncated to an int
    but never drops below 1 pixel.

    Args:
        image: object with PIL-style ``size`` and ``resize`` (e.g. ``PIL.Image.Image``).
        max_size: target length in pixels for the longer side.

    Returns:
        Whatever ``image.resize`` returns for the computed ``(width, height)``.
    """
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = max_size
        # Clamp: truncation yields 0 for very elongated images, and Pillow
        # raises ValueError for dimensions smaller than 1.
        new_height = max(1, int(new_width / aspect_ratio))
    else:
        new_height = max_size
        new_width = max(1, int(new_height * aspect_ratio))
    return image.resize((new_width, new_height))
# Backend endpoint and access token. Both can be overridden via environment
# variables, so the (ephemeral) tunnel URL can be rotated without a code change.
# NOTE(review): the default token is committed in source; prefer setting
# OTTER_API_TOKEN as a Space secret instead.
OTTER_API_URL = os.environ.get(
    "OTTER_API_URL",
    "https://ensures-picture-choices-labels.trycloudflare.com/app/otter",
)
OTTER_API_TOKEN = os.environ.get("OTTER_API_TOKEN", "sk-OtterHD")
# (connect, read) timeout in seconds for the backend call. The read timeout is
# generous because large images take a while to upload and process.
OTTER_API_TIMEOUT = (10, 300)


def http_bot(image_input, text_input, request: gr.Request):
    """Gradio handler: send the image and question to the Otter backend.

    Args:
        image_input: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        text_input: the user's question.
        request: Gradio request object, used only to log the client address.

    Returns:
        The backend's ``"result"`` field on success. Otherwise a short
        human-readable message for the output textbox; backend failures are
        logged rather than raised.
    """
    client = getattr(request, "client", None)
    logger.info(f"http_bot. ip: {client.host if client is not None else 'unknown'}")
    if image_input is None:
        # encode_image(None) would fail with AttributeError; tell the user instead.
        return "Please upload an image before asking a question."

    logger.info(f"Prompt request: {text_input}")
    base64_image_str = encode_image(image_input)
    payload = {
        "content": [{"prompt": text_input, "image": base64_image_str}],
        "token": OTTER_API_TOKEN,
    }
    # Log only a short prefix of the (very large) base64 image.
    logger.info(f"request: {{'prompt': {text_input!r}, 'image': {base64_image_str[:10]!r}}}")

    try:
        # json= serializes the payload and sets Content-Type: application/json.
        response = requests.post(OTTER_API_URL, json=payload, timeout=OTTER_API_TIMEOUT)
        response.raise_for_status()
        result = response.json()["result"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        # RequestException: network/HTTP errors; ValueError: non-JSON body;
        # KeyError/TypeError: JSON without a "result" field.
        logger.error(f"Otter backend request failed: {err!r}")
        return "Sorry, the backend server failed to respond. Please try again later."

    logger.info(f"response: {{'result': {result!r}}}")
    return result
# Markdown header rendered at the top of the demo page.
title = """
# OTTER-HD: A High-Resolution Multi-modality Model
[[Otter Codebase]](https://github.com/Luodian/Otter) [[Paper]]() [[Checkpoints & Benchmarks]](https://huggingface.co/Otter-AI)
**Tips**:
- Since 1024x1024 images are large, transmitting them from the HF Space to our backend server may take a while. Please be patient while waiting for the response.
- The model currently focuses mainly on high-resolution images and still needs further improvement in (1) hallucination reduction, (2) text formatting control, and other areas you are welcome to suggest.
"""
# Custom CSS passed to gr.Blocks. NOTE(review): no component in this file
# appears to set elem_id="mkd", so this rule may be unused — confirm.
css = """
#mkd {
height: 1000px;
overflow: auto;
border: 1px solid #ccc;
}
"""
if __name__ == "__main__":
    # Build the single-tab VQA UI: image on the left; answer, question box and
    # send button on the right; clickable examples below.
    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        with gr.Tab("Ask a Question"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    image_input = gr.Image(label="Upload a High-Res Image", type="pil")
                with gr.Column(scale=1):
                    vqa_output = gr.Textbox(label="Output")
                    text_input = gr.Textbox(label="Ask a Question")
                    vqa_btn = gr.Button("Send It")
            gr.Examples(
                [
                    [
                        "./assets/IMG_00095.png",
                        "How many camels are inside this image?",
                    ],
                    [
                        "./assets/IMG_00095.png",
                        "How many people are inside this image?",
                    ],
                    [
                        "./assets/IMG_00012.png",
                        "How many apples are there?",
                    ],
                    [
                        "./assets/IMG_00080.png",
                        "What is this and where is it from?",
                    ],
                    [
                        "./assets/IMG_00094.png",
                        "What's important on this website?",
                    ],
                ],
                inputs=[image_input, text_input],
                outputs=[vqa_output],
                fn=http_bot,
                label="Click on any Examples below👇",
            )
        # Clicking the button and pressing Enter in the question box both send
        # the same request.
        vqa_btn.click(fn=http_bot, inputs=[image_input, text_input], outputs=vqa_output)
        text_input.submit(fn=http_bot, inputs=[image_input, text_input], outputs=vqa_output)
    demo.launch()