# llava-4bit / app.py
# merve's picture
# merve HF staff
# Update app.py
# 11e466e
# raw
# history blame
# 5.17 kB
import os
import string
import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re
# Markdown heading shown at the top of the demo page.
DESCRIPTION = "# LLaVA 🌋"

# Checkpoint identifier handed to the transformers pipeline.
model_id = "llava-hf/llava-1.5-7b-hf"

# bitsandbytes settings: load the weights in 4-bit, compute in float16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

# Built once at import time and shared by every `chat` call.
pipe = pipeline(
    task="image-to-text",
    model=model_id,
    model_kwargs=dict(quantization_config=quantization_config),
)
def extract_response_pairs(text):
    """Split a ``USER: ... ASSISTANT: ...`` transcript into chat turns.

    Args:
        text: Transcript holding zero or more ``USER:``/``ASSISTANT:``
            exchanges.

    Returns:
        A list of ``(user, assistant)`` tuples with surrounding whitespace
        stripped. The user element keeps its leading ``USER:`` marker.
    """
    # The terminator is a lookahead so the next turn's "USER:" marker is left
    # in place; consuming it made findall() skip every second exchange.
    pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?=USER:|$)', re.DOTALL)
    return [(user.strip(), assistant.strip())
            for user, assistant in pattern.findall(text)]


def postprocess_output(output: str) -> str:
    """Append a full stop unless *output* is empty or ends in punctuation."""
    if output and output[-1] not in string.punctuation:
        output += "."
    return output


def update_history(history_chat, text, generated_text):
    """Return a new history list that includes the latest model output.

    Args:
        history_chat: Transcript strings accumulated so far.
        text: The user message of the current turn.
        generated_text: Text returned by the pipeline for the current turn.

    Returns:
        A new list of transcript strings; *history_chat* is not mutated.
    """
    # NOTE(review): the transformers image-to-text pipeline appears to decode
    # the whole sequence for LLaVA, prompt included — confirm for the
    # installed version. Both shapes are handled below.
    if generated_text.lstrip().startswith("USER:"):
        # The output echoes the prompt, and the prompt already embeds every
        # earlier turn, so it is the complete transcript on its own.
        # Appending it instead would duplicate all previous exchanges.
        return [generated_text]
    # Bare answer without the echoed prompt: rebuild the turn explicitly.
    return [*history_chat, f"USER: {text}\nASSISTANT: {generated_text}"]


def chat(image, text, temperature, length_penalty,
         repetition_penalty, max_length, min_length, top_p,
         history_chat):
    """Run one chat turn through the module-level ``pipe``.

    Args:
        image: Image passed as the pipeline input.
        text: The user's message for this turn.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, top_p: Forwarded unchanged as ``generate_kwargs``.
        history_chat: List of transcript strings from earlier turns.

    Returns:
        ``(chat_val, history_chat)``: the ``(user, assistant)`` pairs parsed
        from the transcript, and the updated history list.
    """
    # Earlier turns first, then the new question; the image placeholder is
    # only added for the current turn.
    prompt = " ".join([*history_chat, f"USER: <image>\n{text}\nASSISTANT:"])
    # NOTE(review): `do_sample` is not set here, so temperature/top_p may have
    # no effect on generation — confirm against the generate() defaults.
    outputs = pipe(image, prompt=prompt,
                   generate_kwargs={"temperature": temperature,
                                    "length_penalty": length_penalty,
                                    "repetition_penalty": repetition_penalty,
                                    "max_length": max_length,
                                    "min_length": min_length,
                                    "top_p": top_p})
    history_chat = update_history(history_chat, text,
                                  outputs[0]["generated_text"])
    print(f"history_chat is {history_chat}")
    chat_val = extract_response_pairs(" ".join(history_chat))
    print(f"chat_val is {chat_val}")
    return chat_val, history_chat
# Inline stylesheet for an element with id "mkd": fixed 500px height,
# scrollable overflow, thin grey border.
# NOTE(review): nothing visible in this file reads `css` — the Blocks below is
# created with css="style.css" and no component sets elem_id="mkd". Confirm
# whether this constant is still needed.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Builds the demo page: header, chat window, image/text inputs, generation
# sliders, and the event wiring that routes everything through `chat`.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("## LLaVA, one of the greatest multimodal chat models is now available in transformers with 4-bit quantization! ⚡️")
    gr.Markdown("## Try it 4-bit quantized LLaVA this demo 🤗")

    chatbot = gr.Chatbot(label="Chat", show_label=False)
    gr.Markdown("Input image and text and start chatting 👇")
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)

    # Holds the transcript list that `chat` receives and returns.
    history_chat = gr.State(value=[])

    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(
            label="Temperature",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=1.0,
        )
        length_penalty = gr.Slider(
            label="Length Penalty",
            info="Set to larger for longer sequence, used with beam search.",
            minimum=-1.0,
            maximum=2.0,
            step=0.2,
            value=1.0,
        )
        repetition_penalty = gr.Slider(
            label="Repetition Penalty",
            info="Larger value prevents repetition.",
            minimum=1.0,
            maximum=5.0,
            step=0.5,
            value=1.5,
        )
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=512,
            step=1,
            value=50,
        )
        min_length = gr.Slider(
            label="Minimum Length",
            minimum=1,
            maximum=100,
            step=1,
            value=1,
        )
        top_p = gr.Slider(
            label="Top P",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=0.9,
        )

    # Component order must match the positional parameters of `chat`.
    chat_inputs = [
        image,
        text_input,
        temperature,
        length_penalty,
        repetition_penalty,
        max_length,
        min_length,
        top_p,
        history_chat,
    ]
    # `chat` returns (pairs for the chatbot, updated history).
    chat_output = [
        chatbot,
        history_chat,
    ]

    def _reset_conversation():
        """Return empty values for the chatbot and the stored history."""
        return [], []

    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    )
    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # Clears the textbox after a successful turn. The callback returns a
        # single string, so it must target exactly one component; listing all
        # nine inputs makes Gradio reject the return value.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )
    clear_chat_button.click(
        fn=_reset_conversation,
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
        api_name="clear",
    )
    # A new image starts a new conversation.
    image.change(
        fn=_reset_conversation,
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
    )

    examples = [["/content/baklava.png", "How to make this pastry?"], ["/content/bee.png", "Describe this image."]]
    # Each example row has two values (image path, question), so only the two
    # matching components are listed as inputs.
    gr.Examples(examples=examples, inputs=[image, text_input])
if __name__ == "__main__":
    # Enable request queuing with max_size=10, then start the app.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()