# Mantis / app.py
# Last commit f16e094 ("update") by DongfuJiang; file size 4.98 kB.
import gradio as gr
import spaces
import time
from PIL import Image
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
from typing import List
# Hub checkpoint shared by the processor and the model, named once so the two
# cannot drift apart. Both are loaded a single time at import; generate()
# moves the model to CUDA when it runs.
# NOTE(review): MLlavaForConditionalGeneration is imported above but the model
# is loaded as LlavaForConditionalGeneration — confirm which class this
# checkpoint expects.
MODEL_NAME = "MFuyu/mllava_llava_debug_nlvr2_v5_4096"
processor = MLlavaProcessor.from_pretrained(MODEL_NAME)
model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME)
@spaces.GPU
def generate(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream a ``chat_mllava`` reply for the given conversation.

    Runs under ``spaces.GPU`` and moves the module-level ``model`` to CUDA
    before generating.

    Args:
        text: Forwarded as the first argument of ``chat_mllava``. ``bot``
            passes ``None`` here and carries the prompt inside ``history``.
        images: Images for the conversation; an empty list is converted to
            ``None`` before calling ``chat_mllava``.
            NOTE(review): ``bot`` passes the file entries collected from the
            Chatbot history, which are presumably paths rather than PIL
            images — confirm the annotation.
        history: Chat turns as ``{"role": ..., "text": ...}`` dicts,
            forwarded to ``chat_mllava``.
        **kwargs: Generation options forwarded to ``chat_mllava``.

    Yields:
        Each text value streamed by ``chat_mllava``.

    Returns:
        The last streamed text (or ``text`` if nothing was streamed) as the
        generator's return value.
    """
    global model  # rebound below when moved to the GPU
    model = model.to("cuda")
    if not images:
        images = None
    # Use fresh names for the streamed values so the ``text`` and ``history``
    # parameters are not shadowed while the loop runs.
    last_output = text
    for partial_output, _updated_history in chat_mllava(
        text, images, model, processor, history=history, stream=True, **kwargs
    ):
        last_output = partial_output
        yield partial_output
    return last_output
def enable_next_image(uploaded_images, image):
    """Record *image* and hand back a cleared, locked input box.

    ``uploaded_images`` is mutated in place (the image is appended) and is
    also returned as the first output.
    """
    uploaded_images.append(image)
    locked_box = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_box
def add_message(history, message):
    """Append a MultimodalTextbox submission to the Chatbot history.

    Each attached file becomes its own ``[(file,), None]`` entry, followed by
    one ``[text, None]`` entry when text was typed. ``history`` is mutated in
    place and returned together with a cleared input box.
    """
    attached = message["files"]
    if attached:
        history.extend([(attachment,), None] for attachment in attached)
    typed = message["text"]
    if typed:
        history.append([typed, None])
    return history, gr.MultimodalTextbox(value=None)
def print_like_dislike(x: gr.LikeData):
    """Write a Chatbot like/dislike event (index, value, liked) to stdout."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
def get_chat_history(history):
    """Convert Chatbot history into the role/text turns passed to ``generate``.

    Only entries whose user side is a string (text messages) produce turns.
    File entries such as ``[(file,), None]`` are skipped here. Every text
    message yields a user turn followed by an assistant turn. Earlier
    messages must already carry the bot reply. The final history entry must
    not have one, and its assistant turn is left empty.

    Args:
        history: List of ``[user, bot]`` pairs from ``gr.Chatbot``.

    Returns:
        List of ``{"role": "user" | "assistant", "text": ...}`` dicts.

    Raises:
        ValueError: If an earlier text message has no bot reply, or the final
            entry already has one. The check is raised explicitly rather than
            via ``assert`` so it still runs under ``python -O``.
    """
    chat_history = []
    last_index = len(history) - 1
    for i, message in enumerate(history):
        user_part, bot_part = message[0], message[1]
        if not isinstance(user_part, str):
            continue  # file/image entries carry no text turn
        chat_history.append({"role": "user", "text": user_part})
        if i != last_index:
            if not bot_part:
                raise ValueError("The bot message is not provided, internal error")
            chat_history.append({"role": "assistant", "text": bot_part})
        else:
            if bot_part:
                raise ValueError("the bot message internal error, get: {}".format(bot_part))
            # Empty assistant turn for the model to complete.
            chat_history.append({"role": "assistant", "text": ""})
    return chat_history
def get_chat_images(history):
    """Flatten every file entry in the Chatbot history, oldest first.

    Entries whose user side is a tuple (as written by ``add_message`` for
    uploads) contribute all of their items. Text entries are ignored.
    """
    return [
        item
        for message in history
        if isinstance(message[0], tuple)
        for item in message[0]
    ]
def _collect_pending_turn(history):
    """Return ``(text, images)`` sent by the user since the last bot reply."""
    text = ""
    images = []
    # Walk backwards until an entry that already has a (truthy) bot reply.
    for message in reversed(history):
        if message[1]:
            break
        if isinstance(message[0], str):
            text = message[0] + " " + text
        elif isinstance(message[0], tuple):
            images.extend(message[0])
    # Images were gathered newest-first; restore chronological order.
    return text.strip(), images[::-1]


def _align_image_placeholders(text, num_images):
    """Match the ``<image>`` placeholder count in *text* to *num_images*.

    Returns ``(text, warning)``, where ``warning`` is ``None`` when the
    counts already agree. Missing placeholders are prepended to the text, and
    surplus placeholders are removed starting from the end of the text.
    """
    num_placeholders = text.count("<image>")
    if num_placeholders < num_images:
        missing = num_images - num_placeholders
        return (
            "<image> " * missing + text,
            "The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.",
        )
    if num_placeholders > num_images:
        surplus = num_placeholders - num_images
        # rsplit with maxsplit=surplus splits at the last `surplus`
        # occurrences; joining the pieces drops exactly those placeholders.
        return (
            "".join(text.rsplit("<image>", surplus)),
            "The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.",
        )
    return text, None


def bot(history):
    """Stream the model's reply to the user's latest turn into ``history``.

    Gathers the text and images sent since the last bot reply. It then
    reconciles ``<image>`` placeholders with the number of those images, and
    when an adjustment is made it warns the user and rewrites
    ``history[-1][0]``. Finally it streams ``generate`` output into
    ``history[-1][1]``.

    Args:
        history: ``gr.Chatbot`` value as a list of ``[user, bot]`` pairs.

    Yields:
        ``history`` after each streamed chunk.

    Raises:
        gr.Error: If the pending turn contains no text.
    """
    text, images = _collect_pending_turn(history)
    if not text:
        raise gr.Error("Please enter a message")
    text, warning = _align_image_placeholders(text, len(images))
    if warning is not None:
        gr.Warning(warning)
        history[-1][0] = text

    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = {
        "max_new_tokens": 4096,
        "temperature": 0.7,
        "top_p": 1.0,
        "do_sample": True,
    }
    for partial_output in generate(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = partial_output
        time.sleep(0.05)  # brief pause between streamed chunks
        yield history
def build_demo():
    """Build the Gradio Blocks UI for the multi-image chat demo.

    Pressing Enter in the input box and clicking "Send" both run
    ``add_message`` and then stream ``bot`` into the Chatbot.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
        chatbot.like(print_like_dislike, None, None)
        with gr.Row():
            send_button = gr.Button("Send")
            clear_button = gr.ClearButton([chatbot, chat_input])
        # .success (not .then): as on the Enter path above, only run bot when
        # add_message finished without raising.
        # NOTE(review): this reuses api_name="bot_response" from the submit
        # chain; Gradio may rename or reject the duplicate — confirm intended.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and launch it.
    demo = build_demo()
    demo.launch()