Mantis / app_idefics2.py
DongfuJiang's picture
update
a915791
raw
history blame
9.3 kB
import gradio as gr
import spaces
import time
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from typing import List
# Hub checkpoint served by this demo; both the processor and the model are
# loaded once at import time. generate_stream moves the model to "cuda" on
# each request.
_MODEL_ID = "TIGER-Lab/Mantis-8B-Idefics2"
processor = AutoProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(_MODEL_ID)
@spaces.GPU
def generate_stream(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply for ``history``, yielding the text so far.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``. This generator drains the streamer and yields
    the cumulative decoded output after every chunk.

    Args:
        text: Unused; kept so existing callers (``bot`` passes ``None``) work.
        images: Images for the ``{"type": "image"}`` items in ``history``, in
            order. An empty list is passed to the processor as ``None``.
        history: Chat messages passed to ``processor.apply_chat_template``.
        **kwargs: Extra keyword arguments for ``model.generate`` (e.g.
            ``max_new_tokens``). Any ``streamer`` entry is replaced.

    Yields:
        The full reply generated so far (cumulative, not a delta).

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    global processor, model
    from threading import Thread
    from transformers import TextIteratorStreamer

    model = model.to("cuda")
    if not images:
        images = None
    prompt = processor.apply_chat_template(history, add_generation_prompt=True)
    print("Prompt: ")
    print(prompt)
    print("Images: ")
    print(images)
    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs["streamer"] = streamer
    inputs.update(kwargs)

    generate = model.generate
    worker_errors = []

    def _run_generate():
        # If generate() raises, it never signals end-of-stream to the
        # streamer, and the consumer loop below would block forever.
        # Record the error and close the stream so the loop can finish and
        # the error can be re-raised in the caller's context.
        try:
            generate(**inputs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()
    output = ""
    for chunk in streamer:
        output += chunk
        yield output
    thread.join()
    if worker_errors:
        raise worker_errors[0]
def enable_next_image(uploaded_images, image):
    """Record ``image`` and lock the multimodal input box.

    ``uploaded_images`` is appended to in place and also returned, together
    with a cleared, non-interactive ``gr.MultimodalTextbox``.
    """
    uploaded_images.append(image)
    locked_box = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_box
def add_message(history, message):
    """Append a submitted multimodal message to the chat history.

    Each uploaded file becomes its own ``[(path,), None]`` entry. If text was
    entered, a ``[text, None]`` entry follows the files.

    Returns:
        The (in-place updated) history and a cleared ``gr.MultimodalTextbox``.
    """
    history.extend([(path,), None] for path in message["files"] or [])
    typed_text = message["text"]
    if typed_text:
        history.append([typed_text, None])
    return history, gr.MultimodalTextbox(value=None)
def print_like_dislike(x: gr.LikeData):
    """Print a chatbot like/dislike event as ``index value liked``."""
    event_fields = (x.index, x.value, x.liked)
    print(*event_fields)
def get_chat_images(history):
    """Load every uploaded image in ``history``, in chat order.

    Image uploads are the entries whose user part is a tuple; the first
    tuple element is the file passed to ``load_image``.
    """
    return [
        load_image(entry[0][0])
        for entry in history
        if isinstance(entry[0], tuple)
    ]
def get_chat_history(history):
    """Convert Gradio chat history into chat-template messages plus images.

    Each text entry becomes a ``user`` message. Its ``<image>`` placeholders
    are expanded into interleaved ``{"type": "text"}`` and
    ``{"type": "image"}`` content items, one image item per placeholder.

    A non-empty string bot reply on an entry becomes a following
    ``assistant`` message, so the model sees the whole conversation.

    Image-upload entries produce no message of their own; their images are
    returned separately, in chat order.

    Args:
        history: Gradio chatbot history, a list of ``[user, bot]`` pairs where
            ``user`` is a text string or a tuple whose first element is an
            image file.

    Returns:
        ``(messages, images)``: the message list for
        ``processor.apply_chat_template`` and the loaded images, in the order
        their ``{"type": "image"}`` items appear.

    Raises:
        ValueError: if the text uses more ``<image>`` placeholders than there
            are uploaded images.
    """
    images = get_chat_images(history)
    messages = []
    cur_image_idx = 0
    for message in history:
        user_part, bot_reply = message[0], message[1]
        if isinstance(user_part, str):
            num_images = user_part.count("<image>")
            if cur_image_idx + num_images > len(images):
                raise ValueError(
                    "Number of images uploaded is less than the number of <image> "
                    "placeholders in the text. Please upload more images."
                )
            content = []
            if num_images > 0:
                segments = user_part.split("<image>")
                last_segment = len(segments) - 1
                for seg_idx, sub_text in enumerate(segments):
                    if sub_text.strip():
                        content.append({"type": "text", "text": sub_text.strip()})
                    # Placeholders sit *between* segments, so an image goes
                    # after every segment except the last. Emitting one after
                    # the last segment would consume an image belonging to a
                    # later turn.
                    if seg_idx < last_segment:
                        content.append({"type": "image"})
                        cur_image_idx += 1
            else:
                content.append({"type": "text", "text": user_part})
            messages.append({"role": "user", "content": content})
        # Earlier turns carry the bot's reply; include it so the prompt holds
        # the full back-and-forth rather than only user messages.
        if isinstance(bot_reply, str) and bot_reply:
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": bot_reply}]}
            )
    return messages, images
def bot(history):
    """Generate and stream the assistant reply for the newest user turn.

    The newest turn is every history entry after the last one that already
    has a bot reply. Its text entries are joined with spaces and its image
    entries are counted.

    If the number of ``<image>`` placeholders in that text differs from the
    number of uploaded images, the text is repaired and written back to
    ``history[-1][0]``:

    - missing placeholders are prepended;
    - surplus placeholders are removed, starting from the end.

    Args:
        history: Gradio chatbot history (list of ``[user, bot]`` pairs);
            mutated in place.

    Yields:
        ``history`` after each streamed chunk, with ``history[-1][1]`` holding
        the reply generated so far.

    Raises:
        gr.Error: if the newest turn contains no text.
    """
    turn_texts = []
    turn_image_count = 0
    for entry in reversed(history):
        if entry[1]:
            break  # reached a turn the bot has already answered
        user_part = entry[0]
        if isinstance(user_part, str):
            turn_texts.append(user_part)
        elif isinstance(user_part, tuple):
            turn_image_count += len(user_part)
    # Entries were visited newest-first; restore chronological order.
    text = " ".join(reversed(turn_texts)).strip()
    if not text:
        raise gr.Error("Please enter a message")

    placeholder_count = text.count("<image>")
    if placeholder_count < turn_image_count:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        text = "<image> " * (turn_image_count - placeholder_count) + text
        history[-1][0] = text
    elif placeholder_count > turn_image_count:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        # Splitting on the last k placeholders and re-joining drops exactly those k.
        surplus = placeholder_count - turn_image_count
        text = "".join(text.rsplit("<image>", surplus))
        history[-1][0] = text

    chat_history, chat_images = get_chat_history(history)
    reply_stream = generate_stream(
        None,
        chat_images,
        chat_history,
        max_new_tokens=4096,
        num_beams=1,
        do_sample=False,
    )
    for partial_reply in reply_stream:
        history[-1][1] = partial_reply
        time.sleep(0.05)  # brief pause between UI updates
        yield history
def build_demo():
    """Build the Mantis chat UI.

    Submitting the textbox or clicking Send runs ``add_message``. Only if that
    succeeds does ``bot`` stream the reply into the same chatbot, for both
    triggers.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(""" # Mantis
Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.
### [Paper](https://arxiv.org/abs/2405.01483) | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) | [Dataset](https://huggingface.co/datasets/TIGER-Lab/Mantis-Instruct) | [Website](https://tiger-ai-lab.github.io/Mantis/)
""")
        gr.Markdown("""## Chat with Mantis
Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
(The model currently serving is [🤗 TIGER-Lab/Mantis-8B-Idefics2](https://huggingface.co/TIGER-Lab/Mantis-8B-Idefics2))
""")
        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        # NOTE: sampling controls (temperature / top-p sliders) were drafted
        # here but never wired into ``bot``, which decodes greedily
        # (do_sample=False).
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
        chatbot.like(print_like_dislike, None, None)
        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # Mirror the submit path: only run ``bot`` when ``add_message``
        # succeeded (``.then`` would run it even after a failure).
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )

        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"]
                },
                {
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?",
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"]
                },
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
                },
            ],
            inputs=[chat_input],
        )

        gr.Markdown("""
## Citation
```
@article{jiang2024mantis,
title={MANTIS: Interleaved Multi-Image Instruction Tuning},
author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
journal={arXiv preprint arXiv:2405.01483},
year={2024}
}
```""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    demo = build_demo()
    demo.launch()