import gradio as gr from gradio.utils import async_lambda import spaces import time import subprocess import torch from models.mllava import ( MLlavaProcessor, LlavaForConditionalGeneration, prepare_inputs, ) from models.conversation import Conversation, SeparatorStyle from transformers import TextIteratorStreamer from transformers.utils import is_flash_attn_2_available from threading import Thread device = "cuda" if torch.cuda.is_available() else "cpu" IMAGE_TOKEN = "" generation_kwargs = { "max_new_tokens": 128, "num_beams": 1, "do_sample": False, "no_repeat_ngram_size": 3, } if device == "cpu": processor = None model = None else: if not is_flash_attn_2_available(): subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3") processor.tokenizer.pad_token = processor.tokenizer.eos_token model = LlavaForConditionalGeneration.from_pretrained( "HODACHI/Llama-3-EZO-VLM-1", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map=device, ).eval() # Set the system prompt conv_template = Conversation( system="<|start_header_id|>system<|end_header_id|>\n\nあなたは誠実で優秀な日本人のアシスタントです。特に指示が無い場合は、常に日本語で回答してください。", roles=("user", "assistant"), messages=(), offset=0, sep_style=SeparatorStyle.LLAMA_3, sep="<|eot_id|>", ) def get_chat_messages(history): chat_history = [] user_role = conv_template.roles[0] assistant_role = conv_template.roles[1] for i, message in enumerate(history): if isinstance(message[0], str): chat_history.append({"role": user_role, "text": message[0]}) if i != len(history) - 1: assert message[1], "The bot message is not provided, internal error" chat_history.append({"role": assistant_role, "text": message[1]}) else: assert not message[1], "the bot message internal error, get: {}".format( message[1] ) chat_history.append({"role": assistant_role, "text": ""}) return chat_history def get_chat_images(history): images = [] for message in history: if isinstance(message[0], tuple): images.extend(message[0]) return images @spaces.GPU def bot(message, history): if not model: print(message, history) images = message["files"] if message["files"] else None text = message["text"].strip() if not text: raise gr.Error("You must enter a message!") num_image_tokens = text.count(IMAGE_TOKEN) # modify text if images and num_image_tokens < len(images): if num_image_tokens != 0: gr.Warning( "The number of images uploaded is more than the number of placeholders in the text. Will automatically prepend to the text." ) # prefix image tokens text = IMAGE_TOKEN * (len(images) - num_image_tokens) + text if images and num_image_tokens > len(images): raise gr.Error( "The number of images uploaded is less than the number of placeholders in the text!" ) current_messages = [] if images: current_messages += [[(image,), None] for image in images] if text: current_messages += [[text, None]] current_history = history + current_messages chat_messages = get_chat_messages(current_history) chat_images = get_chat_images(current_history) # Generate! inputs = prepare_inputs(None, chat_images, model, processor, history=chat_messages, **generation_kwargs) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) inputs["streamer"] = streamer thread = Thread(target=model.generate, kwargs=inputs) thread.start() buffer = "" for new_text in streamer: buffer += new_text time.sleep(0.01) yield buffer DESCRIPTION = """# 🐟 Llama-3-EZO-VLM-1 🤗 [モデル一覧](https://huggingface.co/HODACHI) | [Llama-3-EZO-VLM-1](https://huggingface.co/HODACHI/Llama-3-EZO-VLM-1)は[Axcxept co., ltd.](https://axcxept.com/)が [SakanaAI/Llama-3-EvoVLM-JP-v2](https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2)をベースに性能向上を行った視覚言語モデルです。 """ examples = [ { "text": "1番目と2番目の画像に写っている動物の違いは何ですか?簡潔に説明してください。", "files": ["./examples/image_0.jpg", "./examples/image_1.jpg"], }, { "text": "2枚の写真について、簡単にそれぞれ説明してください。", "files": ["./examples/image_2.jpg", "./examples/image_3.jpg"], }, ] chat = gr.ChatInterface( fn=bot, multimodal=True, chatbot=gr.Chatbot(label="Chatbot", scale=1, height=500), textbox=gr.MultimodalTextbox( interactive=True, file_types=["image"], # file_count="multiple", placeholder="Enter message or upload images. Please use to indicate the position of uploaded images", show_label=True, ), examples=examples, fill_height=False, stop_btn=None, ) with gr.Blocks(fill_height=True) as demo: gr.Markdown(DESCRIPTION) chat.render() chat.examples_handler.load_input_event.then( fn=async_lambda(lambda: [[], [], None]), outputs=[chat.chatbot, chat.chatbot_state, chat.saved_input], ) gr.Markdown( """ ### チャットの方法 HODACHI/Llama-3-EZO-VLM-1は、画像をテキストの好きな場所に入力として配置することができます。画像をアップロードする場所は、``というフレーズで指定できます。 モデルの推論時に、自動的に``が画像トークンに置き換えられます。また、画像のアップロード数が``の数よりも少ない場合、余分な``が削除されます。 逆に、画像のアップロード数が``の数よりも多い場合、自動的に``が追加されます。 ### 注意事項 本モデルは実験段階のプロトタイプであり、研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。 本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。 Axcxept co., ltd.は、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。 利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。 また、このデモでは、できる限り多くの皆様にお使いいただけるように、出力テキストのサイズを制限しております。""" ) demo.queue().launch()