|
import gradio as gr |
|
import spaces |
|
import time |
|
from PIL import Image |
|
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration |
|
from typing import List |
|
# Both the processor and the model weights are loaded from the same checkpoint
# identifier, so it is spelled once here.
_CHECKPOINT_ID = "MFuyu/mllava_llava_debug_nlvr2_v5_4096"

# Module-level singletons shared by every request; `generate` reads both and
# rebinds `model` after moving it to the GPU.
processor = MLlavaProcessor.from_pretrained(_CHECKPOINT_ID)
model = LlavaForConditionalGeneration.from_pretrained(_CHECKPOINT_ID)
|
|
|
@spaces.GPU  # NOTE(review): presumably requests a GPU worker on Hugging Face Spaces -- confirm
def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply for one chat turn.

    Moves the module-level ``model`` to CUDA, then forwards the prompt, the
    images, the history and any extra keyword arguments to ``chat_mllava``
    with ``stream=True``, yielding the text part of every item it produces.

    Args:
        text: Prompt text handed to ``chat_mllava``.
        images: Images handed to ``chat_mllava``; any falsy value (for
            example an empty list) is passed on as ``None``.
        history: Prior conversation, passed as the ``history`` keyword.
        **kwargs: Forwarded unchanged to ``chat_mllava``.

    Yields:
        The first element of each pair produced by ``chat_mllava``.

    Returns:
        The last yielded text, or ``text`` itself when nothing was produced,
        as the generator's return value.
    """
    global processor, model
    # Rebinding the global keeps the CUDA-resident model for later calls.
    model = model.to("cuda")

    # An empty image collection is normalised to None before the call.
    images = images or None

    latest = text
    stream = chat_mllava(text, images, model, processor, history=history, stream=True, **kwargs)
    # Each streamed item is a pair; only its first element is surfaced.
    for latest, _streamed_history in stream:
        yield latest

    return latest
|
|
|
def enable_next_image(uploaded_images, image):
    """Record a newly uploaded image and return a non-interactive textbox.

    Args:
        uploaded_images: List that collects uploaded images; mutated in place.
        image: The image to append.

    Returns:
        A tuple of the updated ``uploaded_images`` list and a
        ``gr.MultimodalTextbox`` created with ``value=None`` and
        ``interactive=False``.
    """
    uploaded_images.append(image)
    locked_textbox = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_textbox
|
|
|
def add_message(history, message):
    """Append the submitted files and text to the chat history.

    Every entry of ``message["files"]`` becomes its own history row whose
    first element is a one-item tuple holding that entry; a non-empty
    ``message["text"]`` becomes one further row. Each new row has ``None``
    in the bot-reply slot.

    Args:
        history: Chat history as a list of ``[user_part, bot_part]`` rows;
            mutated in place.
        message: Mapping with ``"files"`` and ``"text"`` entries.

    Returns:
        A tuple of the updated ``history`` and a fresh
        ``gr.MultimodalTextbox(value=None)``.
    """
    # `or ()` keeps the original guard: a falsy "files" value adds nothing.
    history.extend([(path,), None] for path in (message["files"] or ()))

    typed_text = message["text"]
    if typed_text:
        history.append([typed_text, None])

    return history, gr.MultimodalTextbox(value=None)
|
|
|
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a like/dislike event.

    Args:
        x: Event payload; its ``index``, ``value`` and ``liked`` attributes
            are written to stdout on a single line.
    """
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
|
|
|
|
|
def get_chat_history(history):
    """Convert the chat rows into a flat list of role/text dictionaries.

    Only rows whose first element is a ``str`` are considered; rows holding
    anything else (for example image tuples) are skipped. Every text row
    contributes a ``{"role": "user", ...}`` entry followed by a
    ``{"role": "assistant", ...}`` entry. For the final row of ``history``
    the assistant text is the empty string.

    Args:
        history: List of ``[user_part, bot_part]`` rows.

    Returns:
        A list of ``{"role": ..., "text": ...}`` dictionaries.

    Raises:
        AssertionError: If a text row that is not the last row has no bot
            reply, or if the last row is a text row that already has one.
    """
    chat_history = []
    last_index = len(history) - 1
    for i, message in enumerate(history):
        if not isinstance(message[0], str):
            continue

        chat_history.append({"role": "user", "text": message[0]})
        reply = message[1]
        # The invariants are raised explicitly rather than via `assert` so
        # they are still enforced when Python runs with -O.
        if i != last_index:
            if not reply:
                raise AssertionError("The bot message is not provided, internal error")
            chat_history.append({"role": "assistant", "text": reply})
        else:
            if reply:
                raise AssertionError("the bot message internal error, get: {}".format(reply))
            # Empty assistant turn for the row that is still unanswered.
            chat_history.append({"role": "assistant", "text": ""})
    return chat_history
|
|
|
def get_chat_images(history):
    """Collect every image referenced in the chat history, in row order.

    Args:
        history: List of ``[user_part, bot_part]`` rows. Rows whose first
            element is a ``tuple`` contribute all items of that tuple; other
            rows are ignored.

    Returns:
        A flat list of the collected tuple items.
    """
    return [
        image
        for message in history
        if isinstance(message[0], tuple)
        for image in message[0]
    ]
|
|
|
def bot(history):
    """Generate the assistant reply for the latest turn and stream it back.

    The rows at the end of ``history`` that have no bot reply yet are merged
    into one pending message (text plus images). The number of ``<image>``
    placeholders in that text is reconciled with the number of pending
    images, the whole history is converted for the model, and the streamed
    reply is written into the last row step by step.

    Args:
        history: List of ``[user_part, bot_part]`` rows; mutated in place.

    Yields:
        ``history`` after each streamed update of the last row's reply.

    Raises:
        gr.Error: If the pending message contains no text.
    """
    print(history)

    # Walk backwards over the rows that have not been answered yet.
    pending_text = ""
    pending_images = []
    for entry in reversed(history):
        if entry[1]:
            break
        content = entry[0]
        if isinstance(content, str):
            pending_text = content + " " + pending_text
        elif isinstance(content, tuple):
            pending_images.extend(content)
    pending_text = pending_text.strip()
    # Rows were visited newest-first; flip the collected images back.
    pending_images.reverse()

    if not pending_text:
        raise gr.Error("Please enter a message")

    placeholder_count = pending_text.count("<image>")
    image_count = len(pending_images)
    if placeholder_count < image_count:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        pending_text = "<image> " * (image_count - placeholder_count) + pending_text
        history[-1][0] = pending_text
    elif placeholder_count > image_count:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        surplus = placeholder_count - image_count
        # Reversing the string lets a counted replace() drop the *last*
        # `surplus` placeholders instead of the first ones.
        flipped = pending_text[::-1].replace("<image>"[::-1], "", surplus)
        pending_text = flipped[::-1]
        history[-1][0] = pending_text

    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = {
        "max_new_tokens": 4096,
        "temperature": 0.7,
        "top_p": 1.0,
        "do_sample": True,
    }
    print(None, chat_images, chat_history, generation_kwargs)

    # The prompt text is passed as None; the conversation travels in chat_history.
    for _output in generate(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = _output
        time.sleep(0.05)
        yield history
|
|
|
def build_demo():
    """Assemble the Gradio Blocks chat interface and wire its events.

    The layout consists of a chatbot pane, a multimodal textbox that accepts
    image files, and a row with a "Send" button and a clear button. Both
    submitting the textbox and clicking "Send" first call ``add_message``
    and, only if that step succeeds, stream the reply through ``bot``.

    Returns:
        The constructed ``gr.Blocks`` application (not launched).
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images",
            show_label=True,
        )

        # Textbox submit path: append the message, then generate the reply.
        chat_input.submit(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # Button path mirrors the submit path. It is chained with .success
        # (not .then) so `bot` is skipped when `add_message` fails, matching
        # the submit path above.
        # NOTE(review): both chains register api_name="bot_response"; confirm
        # how the installed Gradio version resolves the duplicate name.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )
    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface and start serving it. The app
    # object stays bound to the name `demo`.
    demo = build_demo()
    demo.launch()